Posted to commits@spark.apache.org by yh...@apache.org on 2015/12/10 21:03:34 UTC

spark git commit: [SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema

Repository: spark
Updated Branches:
  refs/heads/master d9d354ed4 -> bc5f56aa6


[SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema

https://issues.apache.org/jira/browse/SPARK-12250

Author: Yin Huai <yh...@databricks.com>

Closes #10236 from yhuai/SPARK-12250.
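
For reference, a minimal sketch of what this change enables: a UDAF that declares an
empty inputSchema (StructType(Nil)) can now be registered and invoked with arguments,
since the argument-count check against inputSchema in ScalaUDAF is removed. This mirrors
the ScalaAggregateFunctionWithoutInputSchema test added below; the class name
SumOfNestedInts and the table name t are illustrative only, not part of the commit.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types._

    // A UDAF whose inputSchema is intentionally left empty. After this change,
    // Spark no longer rejects it when it is called with an argument.
    class SumOfNestedInts extends UserDefinedAggregateFunction {
      def inputSchema: StructType = StructType(Nil)
      def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)
      def dataType: DataType = LongType
      def deterministic: Boolean = true
      def initialize(buffer: MutableAggregationBuffer): Unit = buffer.update(0, 0L)
      // Sums the "v" field of an array-of-structs argument into the buffer.
      def update(buffer: MutableAggregationBuffer, input: Row): Unit =
        buffer.update(0, buffer.getLong(0) + input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum)
      def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
        buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      def evaluate(buffer: Row): Any = buffer.getLong(0)
    }

    // Usage (sqlContext assumed in scope, e.g. in spark-shell):
    //   sqlContext.udf.register("sumOfNestedInts", new SumOfNestedInts)
    //   sqlContext.sql("SELECT key, sumOfNestedInts(myArray) FROM t GROUP BY key")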


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bc5f56aa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bc5f56aa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bc5f56aa

Branch: refs/heads/master
Commit: bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f
Parents: d9d354e
Author: Yin Huai <yh...@databricks.com>
Authored: Thu Dec 10 12:03:29 2015 -0800
Committer: Yin Huai <yh...@databricks.com>
Committed: Thu Dec 10 12:03:29 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/execution/aggregate/udaf.scala    |  5 --
 .../hive/execution/AggregationQuerySuite.scala  | 64 ++++++++++++++++++++
 2 files changed, 64 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bc5f56aa/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 20359c1..c0d0010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF(
   override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
     copy(inputAggBufferOffset = newInputAggBufferOffset)
 
-  require(
-    children.length == udaf.inputSchema.length,
-    s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
-      s"but ${children.length} are provided.")
-
   override def nullable: Boolean = true
 
   override def dataType: DataType = udaf.dataType

http://git-wip-us.apache.org/repos/asf/spark/blob/bc5f56aa/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 39c0a2a..064c000 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
   }
 }
 
+class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
+
+  def inputSchema: StructType = StructType(Nil)
+
+  def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)
+
+  def dataType: DataType = LongType
+
+  def deterministic: Boolean = true
+
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    buffer.update(0, 0L)
+  }
+
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0))
+  }
+
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
+  }
+
+  def evaluate(buffer: Row): Any = {
+    buffer.getLong(0)
+  }
+}
+
 class LongProductSum extends UserDefinedAggregateFunction {
   def inputSchema: StructType = new StructType()
     .add("a", LongType)
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
       )
     }
   }
+
+  test("udaf without specifying inputSchema") {
+    withTempTable("noInputSchemaUDAF") {
+      sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
+
+      val data =
+        Row(1, Seq(Row(1), Row(2), Row(3))) ::
+          Row(1, Seq(Row(4), Row(5), Row(6))) ::
+          Row(2, Seq(Row(-10))) :: Nil
+      val schema =
+        StructType(
+          StructField("key", IntegerType) ::
+            StructField("myArray",
+              ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil)
+      sqlContext.createDataFrame(
+        sparkContext.parallelize(data, 2),
+        schema)
+        .registerTempTable("noInputSchemaUDAF")
+
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT key, noInputSchema(myArray)
+            |FROM noInputSchemaUDAF
+            |GROUP BY key
+          """.stripMargin),
+        Row(1, 21) :: Row(2, -10) :: Nil)
+
+      checkAnswer(
+        sqlContext.sql(
+          """
+            |SELECT noInputSchema(myArray)
+            |FROM noInputSchemaUDAF
+          """.stripMargin),
+        Row(11) :: Nil)
+    }
+  }
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org