Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2020/06/11 03:52:39 UTC

[GitHub] [spark] HyukjinKwon commented on a change in pull request #28795: [SPARK-31965][TESTS][PYTHON] Move doctests related to Java function registration to test conditionally

HyukjinKwon commented on a change in pull request #28795:
URL: https://github.com/apache/spark/pull/28795#discussion_r438527805



##########
File path: python/pyspark/sql/tests/test_udf.py
##########
@@ -357,6 +359,32 @@ def test_udf_registration_returns_udf(self):
             df.select(add_four("id").alias("plus_four")).collect()
         )
 
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_function(self):
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength", "test.org.apache.spark.sql.JavaStringLength", IntegerType())
+        [value] = self.spark.sql("SELECT javaStringLength('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength2", "test.org.apache.spark.sql.JavaStringLength")
+        [value] = self.spark.sql("SELECT javaStringLength2('test')").first()
+        self.assertEqual(value, 4)
+
+        self.spark.udf.registerJavaFunction(
+            "javaStringLength3", "test.org.apache.spark.sql.JavaStringLength", "integer")
+        [value] = self.spark.sql("SELECT javaStringLength3('test')").first()
+        self.assertEqual(value, 4)
+
+    @unittest.skipIf(not test_compiled, test_not_compiled_message)
+    def test_register_java_udaf(self):
+        self.spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")
+        df = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
+        df.createOrReplaceTempView("df")
+        row = self.spark.sql(
+            "SELECT name, javaUDAF(id) as avg from df group by name order by name desc").first()
+        self.assertEqual(row.asDict(), Row(name='b', avg=102.0).asDict())

Review comment:
       I think we could compare them as they are. The `asDict()` comparison is just there to guard against an issue such as SPARK-29748, or similar issues in the future.
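
       To illustrate the concern, here is a minimal sketch. It assumes (my reading, not stated in this thread) that SPARK-29748 is about the ordering of fields in `Row`, and it uses plain Python namedtuples as hypothetical stand-ins for `Row` so that it runs without PySpark. It only shows why a dict-based comparison is insensitive to field order while a positional comparison is not.

```python
from collections import namedtuple

# Hypothetical stand-ins for pyspark.sql.Row (assumption: not the real class).
# Both hold the same fields, but in a different positional order.
RowAsQueried = namedtuple("RowAsQueried", ["name", "avg"])  # order of the SELECT list
RowSorted = namedtuple("RowSorted", ["avg", "name"])        # fields sorted by name

actual = RowAsQueried(name="b", avg=102.0)
expected = RowSorted(avg=102.0, name="b")

# Tuple equality is positional, so a different field order makes it fail.
positional_equal = tuple(actual) == tuple(expected)

# Plain dict equality ignores key order; dict() is applied explicitly because
# _asdict() returns an OrderedDict on older Python versions, and OrderedDict
# equality is order-sensitive.
dict_equal = dict(actual._asdict()) == dict(expected._asdict())

print("positional comparison:", positional_equal)
print("dict comparison:", dict_equal)
```

       In the test above, comparing `row.asDict()` with `Row(name='b', avg=102.0).asDict()` plays the same role as the dict comparison here: the assertion keeps passing even if the field order of a constructed `Row` changes.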



