You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2015/12/10 21:03:34 UTC
spark git commit: [SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema
Repository: spark
Updated Branches:
refs/heads/master d9d354ed4 -> bc5f56aa6
[SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema
https://issues.apache.org/jira/browse/SPARK-12250
Author: Yin Huai <yh...@databricks.com>
Closes #10236 from yhuai/SPARK-12250.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc5f56aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc5f56aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc5f56aa
Branch: refs/heads/master
Commit: bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f
Parents: d9d354e
Author: Yin Huai <yh...@databricks.com>
Authored: Thu Dec 10 12:03:29 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Dec 10 12:03:29 2015 -0800
----------------------------------------------------------------------
.../spark/sql/execution/aggregate/udaf.scala | 5 --
.../hive/execution/AggregationQuerySuite.scala | 64 ++++++++++++++++++++
2 files changed, 64 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/bc5f56aa/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 20359c1..c0d0010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- require(
- children.length == udaf.inputSchema.length,
- s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
- s"but ${children.length} are provided.")
-
override def nullable: Boolean = true
override def dataType: DataType = udaf.dataType
http://git-wip-us.apache.org/repos/asf/spark/blob/bc5f56aa/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 39c0a2a..064c000 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
}
}
+// Test UDAF whose inputSchema is an empty StructType, i.e. it describes no input columns.
+// update() still reads column 0 of each input row as a Seq[Row], sums the Int field "v"
+// of every element, and adds that sum to the Long kept in slot 0 of the buffer.
+class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
+
+ // Empty schema: no input fields are declared.
+ def inputSchema: StructType = StructType(Nil)
+
+ // The aggregation buffer has a single LongType field named "value".
+ def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)
+
+ // The aggregate's result type is LongType.
+ def dataType: DataType = LongType
+
+ def deterministic: Boolean = true
+
+ // Sets the running total in buffer slot 0 to 0L.
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer.update(0, 0L)
+ }
+
+ // Adds the sum of the "v" Int fields found in input column 0 to the running total.
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0))
+ }
+
+ // Combines two partial totals: buffer1's slot 0 becomes buffer1's total plus buffer2's.
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
+ }
+
+ // Returns the accumulated total stored in buffer slot 0.
+ def evaluate(buffer: Row): Any = {
+ buffer.getLong(0)
+ }
+}
+
class LongProductSum extends UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType()
.add("a", LongType)
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
)
}
}
+
+ // Registers ScalaAggregateFunctionWithoutInputSchema under the SQL name "noInputSchema",
+ // builds a temp table with an Int key and an array-of-struct column, and checks the
+ // aggregate both grouped by key and over the whole table.
+ test("udaf without specifying inputSchema") {
+ withTempTable("noInputSchemaUDAF") {
+ sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
+
+ // Three rows: key 1 carries v = 1, 2, 3 and v = 4, 5, 6; key 2 carries v = -10.
+ val data =
+ Row(1, Seq(Row(1), Row(2), Row(3))) ::
+ Row(1, Seq(Row(4), Row(5), Row(6))) ::
+ Row(2, Seq(Row(-10))) :: Nil
+ // "myArray" is an array of single-field structs; the Int field is named "v".
+ val schema =
+ StructType(
+ StructField("key", IntegerType) ::
+ StructField("myArray",
+ ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil)
+ sqlContext.createDataFrame(
+ sparkContext.parallelize(data, 2),
+ schema)
+ .registerTempTable("noInputSchemaUDAF")
+
+ // Grouped by key: key 1 -> 1 + 2 + 3 + 4 + 5 + 6 = 21, key 2 -> -10.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key, noInputSchema(myArray)
+ |FROM noInputSchemaUDAF
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 21) :: Row(2, -10) :: Nil)
+
+ // No grouping: all rows fall into one group, 21 + (-10) = 11.
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT noInputSchema(myArray)
+ |FROM noInputSchemaUDAF
+ """.stripMargin),
+ Row(11) :: Nil)
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org