You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/01 19:30:58 UTC

spark git commit: [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.

Repository: spark
Updated Branches:
  refs/heads/master 69c5dee2f -> 4137f769b


[SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.

This patch doesn't actually introduce any code that uses the new ExpectsInputTypes trait. It just adds the trait so others can use it. It also renames the old expectedChildTypes function to just inputTypes, and the AddCastForAutoCastInputTypes rule to ImplicitTypeCasts.

We should also add implicit type casting in the future.

Author: Reynold Xin <rx...@databricks.com>

Closes #7151 from rxin/expects-input-types and squashes the following commits:

16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4137f769
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4137f769
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4137f769

Branch: refs/heads/master
Commit: 4137f769b84300648ad933b0b3054d69a7316745
Parents: 69c5dee
Author: Reynold Xin <rx...@databricks.com>
Authored: Wed Jul 1 10:30:54 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed Jul 1 10:30:54 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  1 -
 .../catalyst/analysis/HiveTypeCoercion.scala    |  8 +++---
 .../sql/catalyst/expressions/Expression.scala   | 29 +++++++++++++++++---
 .../spark/sql/catalyst/expressions/math.scala   |  6 ++--
 .../spark/sql/catalyst/expressions/misc.scala   |  8 +++---
 .../sql/catalyst/expressions/predicates.scala   |  6 ++--
 .../catalyst/expressions/stringOperations.scala | 10 +++----
 7 files changed, 44 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a069b47..583338d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -26,7 +26,6 @@ import org.apache.spark.sql.types._
  * Throws user facing errors when passed invalid queries that fail to analyze.
  */
 trait CheckAnalysis {
-  self: Analyzer =>
 
   /**
    * Override to provide additional checks for correct analysis.

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index a9d396d..2ab5cb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -45,7 +45,7 @@ object HiveTypeCoercion {
       IfCoercion ::
       Division ::
       PropagateTypes ::
-      AddCastForAutoCastInputTypes ::
+      ImplicitTypeCasts ::
       Nil
 
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
@@ -705,13 +705,13 @@ object HiveTypeCoercion {
    * Casts types according to the expected input types for Expressions that have the trait
    * [[AutoCastInputTypes]].
    */
-  object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] {
+  object ImplicitTypeCasts extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes =>
-        val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map {
+      case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes =>
+        val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map {
           case (child, actual, expected) =>
             if (actual == expected) child else Cast(child, expected)
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b5063f3..e18a311 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -266,16 +266,37 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
 }
 
 /**
+ * A trait that is mixed in to define the expected input types of an expression.
+ */
+trait ExpectsInputTypes { self: Expression =>
+
+  /**
+   * Expected input types from child expressions. The i-th position in the returned seq indicates
+   * the type requirement for the i-th child.
+   *
+   * The possible values at each position are:
+   * 1. a specific data type, e.g. LongType, StringType.
+   * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType.
+   * 3. a list of specific data types, e.g. Seq(StringType, BinaryType).
+   */
+  def inputTypes: Seq[Any]
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // Checking is deferred to `HiveTypeCoercion` (not implemented for this trait yet), so succeed.
+    TypeCheckResult.TypeCheckSuccess
+  }
+}
+
+/**
  * Expressions that require a specific `DataType` as input should implement this trait
  * so that the proper type conversions can be performed in the analyzer.
  */
-trait AutoCastInputTypes {
-  self: Expression =>
+trait AutoCastInputTypes { self: Expression =>
 
-  def expectedChildTypes: Seq[DataType]
+  def inputTypes: Seq[DataType]
 
   override def checkInputDataTypes(): TypeCheckResult = {
-    // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
+    // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`,
     // so type mismatch error won't be reported here, but for underling `Cast`s.
     TypeCheckResult.TypeCheckSuccess
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index da63f2f..b51318d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
   extends UnaryExpression with Serializable with AutoCastInputTypes {
   self: Product =>
 
-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
+  override def inputTypes: Seq[DataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
@@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
   extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product =>
 
-  override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
+  override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
 
   override def toString: String = s"$name($left, $right)"
 
@@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia
 case class Bin(child: Expression)
   extends UnaryExpression with Serializable with AutoCastInputTypes {
 
-  override def expectedChildTypes: Seq[DataType] = Seq(LongType)
+  override def inputTypes: Seq[DataType] = Seq(LongType)
   override def dataType: DataType = StringType
 
   override def eval(input: InternalRow): Any = {

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index a7bcbe4..407023e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -36,7 +36,7 @@ case class Md5(child: Expression)
 
   override def dataType: DataType = StringType
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
@@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression)
 
   override def toString: String = s"SHA2($left, $right)"
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType)
 
   override def eval(input: InternalRow): Any = {
     val evalE1 = left.eval(input)
@@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp
 
   override def dataType: DataType = StringType
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)
@@ -179,7 +179,7 @@ case class Crc32(child: Expression)
 
   override def dataType: DataType = LongType
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BinaryType)
+  override def inputTypes: Seq[DataType] = Seq(BinaryType)
 
   override def eval(input: InternalRow): Any = {
     val value = child.eval(input)

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 98cd5aa..a777f77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -72,7 +72,7 @@ trait PredicateHelper {
 case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes {
   override def toString: String = s"NOT $child"
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType)
 
   override def eval(input: InternalRow): Any = {
     child.eval(input) match {
@@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any])
 case class And(left: Expression, right: Expression)
   extends BinaryExpression with Predicate with AutoCastInputTypes {
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
 
   override def symbol: String = "&&"
 
@@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression)
 case class Or(left: Expression, right: Expression)
   extends BinaryExpression with Predicate with AutoCastInputTypes {
 
-  override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
+  override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType)
 
   override def symbol: String = "||"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4137f769/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index ce184e4..4cbfc4e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes {
 
   override def nullable: Boolean = left.nullable || right.nullable
   override def dataType: DataType = BooleanType
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
   // try cache the pattern for Literal
   private lazy val cache: Pattern = right match {
@@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes {
   def convert(v: UTF8String): UTF8String
 
   override def dataType: DataType = StringType
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType)
 
   override def eval(input: InternalRow): Any = {
     val evaluated = child.eval(input)
@@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes {
 
   override def nullable: Boolean = left.nullable || right.nullable
 
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
 
   override def eval(input: InternalRow): Any = {
     val leftEval = left.eval(input)
@@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
     if (str.dataType == BinaryType) str.dataType else StringType
   }
 
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
+  override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
 
   override def children: Seq[Expression] = str :: pos :: len :: Nil
 
@@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression)
  */
 case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes {
   override def dataType: DataType = IntegerType
-  override def expectedChildTypes: Seq[DataType] = Seq(StringType)
+  override def inputTypes: Seq[DataType] = Seq(StringType)
 
   override def eval(input: InternalRow): Any = {
     val string = child.eval(input)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org