You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/07/05 17:59:12 UTC
spark git commit: [SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs
Repository: spark
Updated Branches:
refs/heads/master 960298ee6 -> 742da0868
[SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs
## What changes were proposed in this pull request?
Support registering Java UDAFs in PySpark so that users can call Java UDAFs from PySpark. In addition, this adds the corresponding APIs (`registerJavaFunction` and `registerJavaUDAF`) to PySpark's `UDFRegistration`.
## How was this patch tested?
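Unit tests are added: Python tests for registering non-existent Java UDF/UDAF classes, and a Java `JavaUDAFSuite` that registers and queries a Java UDAF.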
Author: Jeff Zhang <zj...@apache.org>
Closes #17222 from zjffdu/SPARK-19439.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/742da086
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/742da086
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/742da086
Branch: refs/heads/master
Commit: 742da0868534dab3d4d7b7edbe5ba9dc8bf26cc8
Parents: 960298e
Author: Jeff Zhang <zj...@apache.org>
Authored: Wed Jul 5 10:59:10 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Wed Jul 5 10:59:10 2017 -0700
----------------------------------------------------------------------
python/pyspark/sql/context.py | 23 ++++
python/pyspark/sql/tests.py | 10 ++
.../org/apache/spark/sql/UDFRegistration.scala | 33 ++++-
.../org/apache/spark/sql/JavaUDAFSuite.java | 55 ++++++++
.../test/org/apache/spark/sql/MyDoubleAvg.java | 129 +++++++++++++++++++
.../test/org/apache/spark/sql/MyDoubleSum.java | 118 +++++++++++++++++
sql/hive/pom.xml | 7 +
.../spark/sql/hive/JavaDataFrameSuite.java | 2 +-
.../spark/sql/hive/aggregate/MyDoubleAvg.java | 129 -------------------
.../spark/sql/hive/aggregate/MyDoubleSum.java | 118 -----------------
.../hive/execution/AggregationQuerySuite.scala | 5 +-
11 files changed, 374 insertions(+), 255 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/python/pyspark/sql/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 426f07c..c44ab24 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -232,6 +232,23 @@ class SQLContext(object):
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)
+ @ignore_unicode_prefix
+ @since(2.3)
+ def registerJavaUDAF(self, name, javaClassName):
+ """Register a java UDAF so it can be used in SQL statements.
+
+ :param name: name of the UDAF
+ :param javaClassName: fully qualified name of java class
+
+ >>> sqlContext.registerJavaUDAF("javaUDAF",
+ ... "test.org.apache.spark.sql.MyDoubleAvg")
+ >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
+ >>> df.registerTempTable("df")
+ >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
+ [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
+ """
+ self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)
+
# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
@@ -551,6 +568,12 @@ class UDFRegistration(object):
def register(self, name, f, returnType=StringType()):
return self.sqlContext.registerFunction(name, f, returnType)
+ def registerJavaFunction(self, name, javaClassName, returnType=None):
+ self.sqlContext.registerJavaFunction(name, javaClassName, returnType)
+
+ def registerJavaUDAF(self, name, javaClassName):
+ self.sqlContext.registerJavaUDAF(name, javaClassName)
+
register.__doc__ = SQLContext.registerFunction.__doc__
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 16ba8bd..c0e3b8d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -481,6 +481,16 @@ class SQLTests(ReusedPySparkTestCase):
df.select(add_three("id").alias("plus_three")).collect()
)
+ def test_non_existed_udf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
+ lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))
+
+ def test_non_existed_udaf(self):
+ spark = self.spark
+ self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
+ lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
+
def test_multiLine_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index ad01b88..8bdc022 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql
-import java.io.IOException
import java.lang.reflect.{ParameterizedType, Type}
import scala.reflect.runtime.universe.TypeTag
@@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
.map(_.asInstanceOf[ParameterizedType])
.filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
if (udfInterfaces.length == 0) {
- throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
+ throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface")
} else if (udfInterfaces.length > 1) {
- throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
+ throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
} else {
try {
val udf = clazz.newInstance()
@@ -491,20 +490,42 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
- case n => logError(s"UDF class with ${n} type arguments is not supported ")
+ case n =>
+ throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.")
}
} catch {
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
- logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
+ throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
}
}
} catch {
- case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
+ case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
}
}
/**
+ * Register a Java UDAF class using reflection, for use from pyspark
+ *
+ * @param name UDAF name
+ * @param className fully qualified class name of UDAF
+ */
+ private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
+ try {
+ val clazz = Utils.classForName(className)
+ if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
+ throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction")
+ }
+ val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]
+ register(name, udaf)
+ } catch {
+ case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
+ case e @ (_: InstantiationException | _: IllegalArgumentException) =>
+ throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
+ }
+ }
+
+ /**
* Register a user-defined function with 1 arguments.
* @since 1.3.0
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
new file mode 100644
index 0000000..ddbaa45
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+
+public class JavaUDAFSuite {
+
+ private transient SparkSession spark;
+
+ @Before
+ public void setUp() {
+ spark = SparkSession.builder()
+ .master("local[*]")
+ .appName("testing")
+ .getOrCreate();
+ }
+
+ @After
+ public void tearDown() {
+ spark.stop();
+ spark = null;
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void udf1Test() {
+ spark.range(1, 10).toDF("value").registerTempTable("df");
+ spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName());
+ Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head();
+ Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6);
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
new file mode 100644
index 0000000..447a71d
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
+ * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
+ * of the average value of input values and 100.0.
+ */
+public class MyDoubleAvg extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleAvg() {
+ List<StructField> inputFields = new ArrayList<>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
+
+ // The buffer has two values, bufferSum for storing the current sum and
+ // bufferCount for storing the number of non-null input values that have been contributed
+ // to the current sum.
+ List<StructField> bufferFields = new ArrayList<>();
+ bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
+ bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType dataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
+ buffer.update(0, null);
+ // The initial value of the count is 0.
+ buffer.update(1, 0L);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
+ if (!input.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer and set the bufferCount to 1.
+ if (buffer.isNullAt(0)) {
+ buffer.update(0, input.getDouble(0));
+ buffer.update(1, 1L);
+ } else {
+ // Otherwise, update the bufferSum and increment bufferCount.
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ buffer.update(1, buffer.getLong(1) + 1L);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's sum value is not null.
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set it to the input buffer's value.
+ buffer1.update(0, buffer2.getDouble(0));
+ buffer1.update(1, buffer2.getLong(1));
+ } else {
+ // Otherwise, we update the bufferSum and bufferCount.
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ // If the bufferSum is still null, we return null because this function has not got
+ // any input row.
+ return null;
+ } else {
+ // Otherwise, we calculate the special average value.
+ return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java
new file mode 100644
index 0000000..93d2033
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.expressions.MutableAggregationBuffer;
+import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
+ * {@link org.apache.spark.sql.types.DoubleType} column.
+ */
+public class MyDoubleSum extends UserDefinedAggregateFunction {
+
+ private StructType _inputDataType;
+
+ private StructType _bufferSchema;
+
+ private DataType _returnDataType;
+
+ public MyDoubleSum() {
+ List<StructField> inputFields = new ArrayList<>();
+ inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
+ _inputDataType = DataTypes.createStructType(inputFields);
+
+ List<StructField> bufferFields = new ArrayList<>();
+ bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
+ _bufferSchema = DataTypes.createStructType(bufferFields);
+
+ _returnDataType = DataTypes.DoubleType;
+ }
+
+ @Override public StructType inputSchema() {
+ return _inputDataType;
+ }
+
+ @Override public StructType bufferSchema() {
+ return _bufferSchema;
+ }
+
+ @Override public DataType dataType() {
+ return _returnDataType;
+ }
+
+ @Override public boolean deterministic() {
+ return true;
+ }
+
+ @Override public void initialize(MutableAggregationBuffer buffer) {
+ // The initial value of the sum is null.
+ buffer.update(0, null);
+ }
+
+ @Override public void update(MutableAggregationBuffer buffer, Row input) {
+ // This input Row only has a single column storing the input value in Double.
+ // We only update the buffer when the input value is not null.
+ if (!input.isNullAt(0)) {
+ if (buffer.isNullAt(0)) {
+ // If the buffer value (the intermediate result of the sum) is still null,
+ // we set the input value to the buffer.
+ buffer.update(0, input.getDouble(0));
+ } else {
+ // Otherwise, we add the input value to the buffer value.
+ Double newValue = input.getDouble(0) + buffer.getDouble(0);
+ buffer.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
+ // buffer1 and buffer2 have the same structure.
+ // We only update the buffer1 when the input buffer2's value is not null.
+ if (!buffer2.isNullAt(0)) {
+ if (buffer1.isNullAt(0)) {
+ // If the buffer value (intermediate result of the sum) is still null,
+ // we set it to the input buffer's value.
+ buffer1.update(0, buffer2.getDouble(0));
+ } else {
+ // Otherwise, we add the input buffer's value (buffer2) to the mutable
+ // buffer's value (buffer1).
+ Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
+ buffer1.update(0, newValue);
+ }
+ }
+ }
+
+ @Override public Object evaluate(Row buffer) {
+ if (buffer.isNullAt(0)) {
+ // If the buffer value is still null, we return null.
+ return null;
+ } else {
+ // Otherwise, the intermediate sum is the final result.
+ return buffer.getDouble(0);
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/pom.xml
----------------------------------------------------------------------
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 09dcc40..f9462e7 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -59,6 +59,13 @@
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
<type>test-jar</type>
<scope>test</scope>
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
index aefc9cc..636ce10 100644
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.hive.test.TestHive$;
-import org.apache.spark.sql.hive.aggregate.MyDoubleSum;
+import test.org.apache.spark.sql.MyDoubleSum;
public class JavaDataFrameSuite {
private transient SQLContext hc;
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
deleted file mode 100644
index ae0c097..0000000
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.aggregate;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.expressions.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * An example {@link UserDefinedAggregateFunction} to calculate a special average value of a
- * {@link org.apache.spark.sql.types.DoubleType} column. This special average value is the sum
- * of the average value of input values and 100.0.
- */
-public class MyDoubleAvg extends UserDefinedAggregateFunction {
-
- private StructType _inputDataType;
-
- private StructType _bufferSchema;
-
- private DataType _returnDataType;
-
- public MyDoubleAvg() {
- List<StructField> inputFields = new ArrayList<>();
- inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputFields);
-
- // The buffer has two values, bufferSum for storing the current sum and
- // bufferCount for storing the number of non-null input values that have been contribuetd
- // to the current sum.
- List<StructField> bufferFields = new ArrayList<>();
- bufferFields.add(DataTypes.createStructField("bufferSum", DataTypes.DoubleType, true));
- bufferFields.add(DataTypes.createStructField("bufferCount", DataTypes.LongType, true));
- _bufferSchema = DataTypes.createStructType(bufferFields);
-
- _returnDataType = DataTypes.DoubleType;
- }
-
- @Override public StructType inputSchema() {
- return _inputDataType;
- }
-
- @Override public StructType bufferSchema() {
- return _bufferSchema;
- }
-
- @Override public DataType dataType() {
- return _returnDataType;
- }
-
- @Override public boolean deterministic() {
- return true;
- }
-
- @Override public void initialize(MutableAggregationBuffer buffer) {
- // The initial value of the sum is null.
- buffer.update(0, null);
- // The initial value of the count is 0.
- buffer.update(1, 0L);
- }
-
- @Override public void update(MutableAggregationBuffer buffer, Row input) {
- // This input Row only has a single column storing the input value in Double.
- // We only update the buffer when the input value is not null.
- if (!input.isNullAt(0)) {
- // If the buffer value (the intermediate result of the sum) is still null,
- // we set the input value to the buffer and set the bufferCount to 1.
- if (buffer.isNullAt(0)) {
- buffer.update(0, input.getDouble(0));
- buffer.update(1, 1L);
- } else {
- // Otherwise, update the bufferSum and increment bufferCount.
- Double newValue = input.getDouble(0) + buffer.getDouble(0);
- buffer.update(0, newValue);
- buffer.update(1, buffer.getLong(1) + 1L);
- }
- }
- }
-
- @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
- // buffer1 and buffer2 have the same structure.
- // We only update the buffer1 when the input buffer2's sum value is not null.
- if (!buffer2.isNullAt(0)) {
- if (buffer1.isNullAt(0)) {
- // If the buffer value (intermediate result of the sum) is still null,
- // we set the it as the input buffer's value.
- buffer1.update(0, buffer2.getDouble(0));
- buffer1.update(1, buffer2.getLong(1));
- } else {
- // Otherwise, we update the bufferSum and bufferCount.
- Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
- buffer1.update(0, newValue);
- buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1));
- }
- }
- }
-
- @Override public Object evaluate(Row buffer) {
- if (buffer.isNullAt(0)) {
- // If the bufferSum is still null, we return null because this function has not got
- // any input row.
- return null;
- } else {
- // Otherwise, we calculate the special average value.
- return buffer.getDouble(0) / buffer.getLong(1) + 100.0;
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
deleted file mode 100644
index d17fb3e..0000000
--- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java
+++ /dev/null
@@ -1,118 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.aggregate;
-
-import java.util.ArrayList;
-import java.util.List;
-
-import org.apache.spark.sql.expressions.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.Row;
-
-/**
- * An example {@link UserDefinedAggregateFunction} to calculate the sum of a
- * {@link org.apache.spark.sql.types.DoubleType} column.
- */
-public class MyDoubleSum extends UserDefinedAggregateFunction {
-
- private StructType _inputDataType;
-
- private StructType _bufferSchema;
-
- private DataType _returnDataType;
-
- public MyDoubleSum() {
- List<StructField> inputFields = new ArrayList<>();
- inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
- _inputDataType = DataTypes.createStructType(inputFields);
-
- List<StructField> bufferFields = new ArrayList<>();
- bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
- _bufferSchema = DataTypes.createStructType(bufferFields);
-
- _returnDataType = DataTypes.DoubleType;
- }
-
- @Override public StructType inputSchema() {
- return _inputDataType;
- }
-
- @Override public StructType bufferSchema() {
- return _bufferSchema;
- }
-
- @Override public DataType dataType() {
- return _returnDataType;
- }
-
- @Override public boolean deterministic() {
- return true;
- }
-
- @Override public void initialize(MutableAggregationBuffer buffer) {
- // The initial value of the sum is null.
- buffer.update(0, null);
- }
-
- @Override public void update(MutableAggregationBuffer buffer, Row input) {
- // This input Row only has a single column storing the input value in Double.
- // We only update the buffer when the input value is not null.
- if (!input.isNullAt(0)) {
- if (buffer.isNullAt(0)) {
- // If the buffer value (the intermediate result of the sum) is still null,
- // we set the input value to the buffer.
- buffer.update(0, input.getDouble(0));
- } else {
- // Otherwise, we add the input value to the buffer value.
- Double newValue = input.getDouble(0) + buffer.getDouble(0);
- buffer.update(0, newValue);
- }
- }
- }
-
- @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
- // buffer1 and buffer2 have the same structure.
- // We only update the buffer1 when the input buffer2's value is not null.
- if (!buffer2.isNullAt(0)) {
- if (buffer1.isNullAt(0)) {
- // If the buffer value (intermediate result of the sum) is still null,
- // we set the it as the input buffer's value.
- buffer1.update(0, buffer2.getDouble(0));
- } else {
- // Otherwise, we add the input buffer's value (buffer1) to the mutable
- // buffer's value (buffer2).
- Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
- buffer1.update(0, newValue);
- }
- }
- }
-
- @Override public Object evaluate(Row buffer) {
- if (buffer.isNullAt(0)) {
- // If the buffer value is still null, we return null.
- return null;
- } else {
- // Otherwise, the intermediate sum is the final result.
- return buffer.getDouble(0);
- }
- }
-}
http://git-wip-us.apache.org/repos/asf/spark/blob/742da086/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 84f9159..f245a79 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
import scala.util.Random
+import test.org.apache.spark.sql.MyDoubleAvg
+import test.org.apache.spark.sql.MyDoubleSum
+
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._
+
class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
def inputSchema: StructType = schema
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org