Posted to commits@spark.apache.org by rx...@apache.org on 2016/11/30 04:05:19 UTC

spark git commit: [SPARK-18632][SQL] AggregateFunction should not implement ImplicitCastInputTypes

Repository: spark
Updated Branches:
  refs/heads/master 9b670bcae -> af9789a4f


[SPARK-18632][SQL] AggregateFunction should not implement ImplicitCastInputTypes

## What changes were proposed in this pull request?
`AggregateFunction` currently implements `ImplicitCastInputTypes` (which enables implicit casting of input types). There are quite a few situations in which we do not need this, or in which we need more control over the input. A recent example is the aggregate for `CountMinSketch`, which should only accept string, binary, or integral inputs.
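
For intuition, here is a minimal, self-contained sketch (toy types only, not Spark's actual `ImplicitCastInputTypes` machinery) of what the implicit-cast contract does: when a legal cast exists, the analyzer coerces an argument to the declared type instead of failing the type check.

```scala
// Toy model of implicit input casting; all names here are illustrative.
object ImplicitCastSketch {
  sealed trait DataType
  case object IntegerType extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType

  // With an implicit-cast contract, an IntegerType argument offered to a
  // function declaring DoubleType is coerced rather than rejected.
  def implicitCast(actual: DataType, expected: DataType): Option[DataType] =
    (actual, expected) match {
      case (a, e) if a == e          => Some(a)
      case (IntegerType, DoubleType) => Some(DoubleType) // numeric widening
      case _                         => None
    }

  def main(args: Array[String]): Unit = {
    println(implicitCast(IntegerType, DoubleType)) // Some(DoubleType): cast inserted
    println(implicitCast(StringType, DoubleType))  // None: would fail type checking
  }
}
```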

This PR removes `ImplicitCastInputTypes` from `AggregateFunction` and makes a case-by-case decision on the kind of input validation each aggregate should use.
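
A check-only contract, by contrast, validates the declared types without ever inserting a cast. The sketch below (again toy types, not Spark's real `ExpectsInputTypes` trait) models the behavior now chosen for the `CountMinSketch` aggregate:

```scala
// Toy model of a check-only input contract; illustrative names only.
object ExpectsInputTypesSketch {
  sealed trait DataType
  case object IntegerType extends DataType
  case object DoubleType extends DataType
  case object StringType extends DataType
  case object BinaryType extends DataType

  // Types accepted for the sketched column: string, binary, or integral.
  val accepted: Set[DataType] = Set(StringType, BinaryType, IntegerType)

  def checkInputDataTypes(actual: DataType): Either[String, Unit] =
    if (accepted.contains(actual)) Right(())
    else Left(s"count_min_sketch does not accept $actual input")

  def main(args: Array[String]): Unit = {
    println(checkInputDataTypes(StringType)) // Right(())
    println(checkInputDataTypes(DoubleType)) // Left(...): rejected, no cast attempted
  }
}
```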

## How was this patch tested?
Refactoring only. Existing tests.

Author: Herman van Hovell <hv...@databricks.com>

Closes #16066 from hvanhovell/SPARK-18632.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/af9789a4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/af9789a4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/af9789a4

Branch: refs/heads/master
Commit: af9789a4f5d00b3141f102e9f0ca52217e26c082
Parents: 9b670bc
Author: Herman van Hovell <hv...@databricks.com>
Authored: Tue Nov 29 20:05:15 2016 -0800
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Nov 29 20:05:15 2016 -0800

----------------------------------------------------------------------
 .../aggregate/ApproximatePercentile.scala       |  7 +++---
 .../expressions/aggregate/Average.scala         |  2 +-
 .../aggregate/CentralMomentAgg.scala            |  3 ++-
 .../catalyst/expressions/aggregate/Corr.scala   |  3 ++-
 .../catalyst/expressions/aggregate/Count.scala  |  3 ---
 .../aggregate/CountMinSketchAgg.scala           |  5 ++--
 .../expressions/aggregate/Covariance.scala      |  3 ++-
 .../catalyst/expressions/aggregate/First.scala  | 26 ++++++++++++++------
 .../aggregate/HyperLogLogPlusPlus.scala         |  2 --
 .../catalyst/expressions/aggregate/Last.scala   | 26 ++++++++++++++------
 .../catalyst/expressions/aggregate/Max.scala    |  3 ---
 .../catalyst/expressions/aggregate/Min.scala    |  3 ---
 .../expressions/aggregate/Percentile.scala      |  9 ++++---
 .../expressions/aggregate/PivotFirst.scala      |  2 --
 .../catalyst/expressions/aggregate/Sum.scala    |  2 +-
 .../expressions/aggregate/collect.scala         |  2 --
 .../expressions/aggregate/interfaces.scala      |  2 +-
 .../expressions/windowExpressions.scala         |  2 --
 .../aggregate/TypedAggregateExpression.scala    |  2 --
 .../spark/sql/execution/aggregate/udaf.scala    |  2 +-
 .../spark/sql/CountMinSketchAggQuerySuite.scala |  8 +++---
 .../sql/TypedImperativeAggregateSuite.scala     |  7 +++---
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  4 ---
 .../sql/hive/execution/TestingTypedCount.scala  |  2 --
 24 files changed, 67 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
index 692cbd7..c2cd895 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -22,11 +22,11 @@ import java.nio.ByteBuffer
 import com.google.common.primitives.{Doubles, Ints, Longs}
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{InternalRow}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
 import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
@@ -71,7 +71,8 @@ case class ApproximatePercentile(
     percentageExpression: Expression,
     accuracyExpression: Expression,
     override val mutableAggBufferOffset: Int,
-    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[PercentileDigest] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
     this(child, percentageExpression, accuracyExpression, 0, 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index d523420..c423e17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the mean calculated from values of a group.")
-case class Average(child: Expression) extends DeclarativeAggregate {
+case class Average(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def prettyName: String = "avg"
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 1a93f45..572d29c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -42,7 +42,8 @@ import org.apache.spark.sql.types._
  *
  * @param child to compute central moments of.
  */
-abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate {
+abstract class CentralMomentAgg(child: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   /**
    * The central moment order to be computed.

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
index 657f519..95a4a0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala
@@ -32,7 +32,8 @@ import org.apache.spark.sql.types._
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2) - Returns Pearson coefficient of correlation between a set of number pairs.")
 // scalastyle:on line.size.limit
-case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate {
+case class Corr(x: Expression, y: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
index bcae0dc..1990f2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala
@@ -38,9 +38,6 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = LongType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType)
-
   private lazy val count = AttributeReference("count", LongType, nullable = false)()
 
   override lazy val aggBufferAttributes = count :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
index 1bfae9e..f5f185f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala
@@ -22,7 +22,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.sketch.CountMinSketch
@@ -52,7 +52,8 @@ case class CountMinSketchAgg(
     confidenceExpression: Expression,
     seedExpression: Expression,
     override val mutableAggBufferOffset: Int,
-    override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[CountMinSketch] {
+    override val inputAggBufferOffset: Int)
+  extends TypedImperativeAggregate[CountMinSketch] with ExpectsInputTypes {
 
   def this(
       child: Expression,

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
index ae5ed77..fc6c34b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types._
  * Compute the covariance between two expressions.
  * When applied on empty data (i.e., count is zero), it returns NULL.
  */
-abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggregate {
+abstract class Covariance(x: Expression, y: Expression)
+  extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = Seq(x, y)
   override def nullable: Boolean = true

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
index 29b8947..bfc58c2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
     _FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows.
       If `isIgnoreNull` is true, returns only non-null values.
   """)
-case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class First(child: Expression, ignoreNullsExpr: Expression)
+  extends DeclarativeAggregate with ExpectsInputTypes {
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class First(child: Expression, ignoreNullsExpr: Expression) extends Declara
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!ignoreNullsExpr.foldable) {
+      TypeCheckFailure(
+        s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
   private lazy val first = AttributeReference("first", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index 77b7eb2..d5c9166 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -140,8 +140,6 @@ case class HyperLogLogPlusPlus(
 
   override def dataType: DataType = LongType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes)
 
   /** Allocate enough words to store all registers. */

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
index b0a363e..96a6ec0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
 
@@ -33,16 +34,11 @@ import org.apache.spark.sql.types._
     _FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows.
       If `isIgnoreNull` is true, returns only non-null values.
   """)
-case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate {
+case class Last(child: Expression, ignoreNullsExpr: Expression)
+  extends DeclarativeAggregate with ExpectsInputTypes {
 
   def this(child: Expression) = this(child, Literal.create(false, BooleanType))
 
-  private val ignoreNulls: Boolean = ignoreNullsExpr match {
-    case Literal(b: Boolean, BooleanType) => b
-    case _ =>
-      throw new AnalysisException("The second argument of First should be a boolean literal.")
-  }
-
   override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
 
   override def nullable: Boolean = true
@@ -56,6 +52,20 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) extends Declarat
   // Expected input data type.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, BooleanType)
 
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val defaultCheck = super.checkInputDataTypes()
+    if (defaultCheck.isFailure) {
+      defaultCheck
+    } else if (!ignoreNullsExpr.foldable) {
+      TypeCheckFailure(
+        s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
+    } else {
+      TypeCheckSuccess
+    }
+  }
+
+  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
+
   private lazy val last = AttributeReference("last", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
index f32c9c6..58fd1d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala
@@ -33,9 +33,6 @@ case class Max(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = child.dataType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
index 9ef42b9..b2724ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala
@@ -33,9 +33,6 @@ case class Min(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = child.dataType
 
-  // Expected input data type.
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 356e088..b51b553 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -54,10 +54,11 @@ import org.apache.spark.util.collection.OpenHashMap
       be between 0.0 and 1.0.
     """)
 case class Percentile(
-  child: Expression,
-  percentageExpression: Expression,
-  mutableAggBufferOffset: Int = 0,
-  inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+    child: Expression,
+    percentageExpression: Expression,
+    mutableAggBufferOffset: Int = 0,
+    inputAggBufferOffset: Int = 0)
+  extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
 
   def this(child: Expression, percentageExpression: Expression) = {
     this(child, percentageExpression, 0, 0)

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 0876060..9ad3124 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -77,8 +77,6 @@ case class PivotFirst(
 
   override val children: Seq[Expression] = pivotColumn :: valueColumn :: Nil
 
-  override lazy val inputTypes: Seq[AbstractDataType] = children.map(_.dataType)
-
   override val nullable: Boolean = false
 
   val valueDataType = valueColumn.dataType

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index f3731d4..96e8cee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.types._
 
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the sum calculated from values of a group.")
-case class Sum(child: Expression) extends DeclarativeAggregate {
+case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes {
 
   override def children: Seq[Expression] = child :: Nil
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
index d2880d5..b176e2a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala
@@ -44,8 +44,6 @@ abstract class Collect extends ImperativeAggregate {
 
   override def dataType: DataType = ArrayType(child.dataType)
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
-
   override def supportsPartial: Boolean = false
 
   override def aggBufferAttributes: Seq[AttributeReference] = Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
index f3fd58b..7397b60 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala
@@ -155,7 +155,7 @@ case class AggregateExpression(
  * Code which accepts [[AggregateFunction]] instances should be prepared to handle both types of
  * aggregate functions.
  */
-sealed abstract class AggregateFunction extends Expression with ImplicitCastInputTypes {
+sealed abstract class AggregateFunction extends Expression {
 
   /** An aggregate function is not foldable. */
   final override def foldable: Boolean = false

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 3cbbcdf..c0d6a6b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -443,7 +443,6 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF
 
 abstract class RowNumberLike extends AggregateWindowFunction {
   override def children: Seq[Expression] = Nil
-  override def inputTypes: Seq[AbstractDataType] = Nil
   protected val zero = Literal(0)
   protected val one = Literal(1)
   protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
@@ -600,7 +599,6 @@ case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindow
  * This documentation has been based upon similar documentation for the Hive and Presto projects.
  */
 abstract class RankLike extends AggregateWindowFunction {
-  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
 
   /** Store the values of the window 'order' expressions. */
   protected val orderAttrs = children.map { expr =>

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 6f7f2f8..9911c0b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -81,8 +81,6 @@ case class TypedAggregateExpression(
 
   override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq)
 
-  override def inputTypes: Seq[AbstractDataType] = Nil
-
   private def aggregatorLiteral =
     Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]]))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 67760f3..ae5e2c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -324,7 +324,7 @@ case class ScalaUDAF(
     udaf: UserDefinedAggregateFunction,
     mutableAggBufferOffset: Int = 0,
     inputAggBufferOffset: Int = 0)
-  extends ImperativeAggregate with NonSQLExpression with Logging {
+  extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes {
 
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
index 4cc5060..3e715a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CountMinSketchAggQuerySuite.scala
@@ -110,9 +110,11 @@ class CountMinSketchAggQuerySuite extends QueryTest with SharedSQLContext {
     withTempView(table) {
       val rdd: RDD[Row] = spark.sparkContext.parallelize(data)
       spark.createDataFrame(rdd, schema).createOrReplaceTempView(table)
-      val cmsSql = schema.fieldNames.map(col => s"count_min_sketch($col, $eps, $confidence, $seed)")
-        .mkString(", ")
-      val result = sql(s"SELECT $cmsSql FROM $table").head()
+
+      val cmsSql = schema.fieldNames.map { col =>
+        s"count_min_sketch($col, ${eps}D, ${confidence}D, $seed)"
+      }
+      val result = sql(s"SELECT ${cmsSql.mkString(", ")} FROM $table").head()
       schema.indices.foreach { i =>
         val binaryData = result.getAs[Array[Byte]](i)
         val in = new ByteArrayInputStream(binaryData)

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
index 0759915..70c3951 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala
@@ -21,13 +21,13 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da
 
 import org.apache.spark.sql.TypedImperativeAggregateSuite.TypedMax
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, SpecificInternalRow}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, GenericInternalRow, ImplicitCastInputTypes, SpecificInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, IntegerType, LongType}
+import org.apache.spark.sql.types._
 
 class TypedImperativeAggregateSuite extends QueryTest with SharedSQLContext {
 
@@ -231,7 +231,8 @@ object TypedImperativeAggregateSuite {
       child: Expression,
       nullable: Boolean = false,
       mutableAggBufferOffset: Int = 0,
-      inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[MaxValue] {
+      inputAggBufferOffset: Int = 0)
+    extends TypedImperativeAggregate[MaxValue] with ImplicitCastInputTypes {
 
 
     override def createAggregationBuffer(): MaxValue = {

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 90e8695..349faae 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -378,10 +378,6 @@ private[hive] case class HiveUDAFFunction(
   @transient
   private lazy val aggBufferSerDe: AggregationBufferSerDe = new AggregationBufferSerDe
 
-  // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our
-  // catalyst type checking framework.
-  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
-
   override def nullable: Boolean = true
 
   override def supportsPartial: Boolean = true

http://git-wip-us.apache.org/repos/asf/spark/blob/af9789a4/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
index a3d48d9..d27287b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/TestingTypedCount.scala
@@ -71,8 +71,6 @@ case class TestingTypedCount(
     TestingTypedCount.State(dataStream.readLong())
   }
 
-  override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil
-
   override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
     copy(mutableAggBufferOffset = newMutableAggBufferOffset)
 

