Posted to commits@spark.apache.org by we...@apache.org on 2020/04/30 03:21:47 UTC

[spark] branch branch-3.0 updated: [SPARK-31553][SQL][TESTS][FOLLOWUP] Tests for collection elem types of `isInCollection`

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 281200f  [SPARK-31553][SQL][TESTS][FOLLOWUP] Tests for collection elem types of `isInCollection`
281200f is described below

commit 281200f59b120bd8b2c75269c69e2cdd6fe20e0a
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Thu Apr 30 03:20:10 2020 +0000

    [SPARK-31553][SQL][TESTS][FOLLOWUP] Tests for collection elem types of `isInCollection`
    
    ### What changes were proposed in this pull request?
    - Add tests for the different element types of collections that can be passed to `isInCollection`, covering the types that pass the `In.checkInputDataTypes()` check (see the sketch below).
    - Test different switch thresholds in the `isInCollection: Scala Collection` test.
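    
    For reference, a minimal sketch of the API being tested (assumes a Spark session with `spark.implicits._` in scope; the data and collection kinds mirror the updated test):
    
    ```scala
    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
    // isInCollection accepts any Scala Iterable (Arrays are wrapped implicitly):
    df.filter($"a".isInCollection(Seq(3, 1)))    // Seq
    df.filter($"a".isInCollection(Set(1, 2)))    // Set
    df.filter($"a".isInCollection(Array(3, 2)))  // Array
    df.filter($"a".isInCollection(List(3, 1)))   // List
    ```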
    
    ### Why are the changes needed?
    To prevent regressions like the one introduced by https://github.com/apache/spark/pull/25754 and reverted by https://github.com/apache/spark/pull/28388.
    
    ### Does this PR introduce any user-facing change?
    No
    
    ### How was this patch tested?
    By existing and new tests in `ColumnExpressionSuite`
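    
    A typical way to run the suite locally, assuming a standard Spark checkout, is `build/sbt "sql/testOnly org.apache.spark.sql.ColumnExpressionSuite"` (the exact invocation may differ per build setup).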
    
    Closes #28405 from MaxGekk/test-isInCollection.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 91648654da259c63178f3fb3f94e3e62e1ef1e45)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 93 +++++++++++++++++-----
 1 file changed, 75 insertions(+), 18 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 8d3b562..4bf19532 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.sql.Date
+import java.sql.{Date, Timestamp}
 import java.util.Locale
 
 import scala.collection.JavaConverters._
@@ -454,26 +454,83 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("isInCollection: Scala Collection") {
-    val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
-    // Test with different types of collections
-    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
-      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
-    checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
-      df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
-    checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
-      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
-    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
-      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+    Seq(0, 1, 10).foreach { optThreshold =>
+      Seq(0, 1, 10).foreach { switchThreshold =>
+        withSQLConf(
+          SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> optThreshold.toString,
+          SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> switchThreshold.toString) {
+          val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+          // Test with different types of collections
+          checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
+            df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+          checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
+            df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
+          checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
+            df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
+          checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
+            df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+
+          val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
 
-    val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
-
-    val e = intercept[AnalysisException] {
-      df2.filter($"a".isInCollection(Seq($"b")))
+          val e = intercept[AnalysisException] {
+            df2.filter($"a".isInCollection(Seq($"b")))
+          }
+          Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
+            .foreach { s =>
+              assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+            }
+        }
+      }
     }
-    Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
-      .foreach { s =>
-        assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+  }
+
+  test("SPARK-31553: isInCollection - collection element types") {
+    val expected = Seq(Row(true), Row(false))
+    Seq(0, 1, 10).foreach { optThreshold =>
+      Seq(0, 1, 10).foreach { switchThreshold =>
+        withSQLConf(
+          SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> optThreshold.toString,
+          SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key -> switchThreshold.toString) {
+          checkAnswer(Seq(0).toDS.select($"value".isInCollection(Seq(null))), Seq(Row(null)))
+          checkAnswer(
+            Seq(true).toDS.select($"value".isInCollection(Seq(true, false))),
+            Seq(Row(true)))
+          checkAnswer(
+            Seq(0.toByte, 1.toByte).toDS.select($"value".isInCollection(Seq(0.toByte, 2.toByte))),
+            expected)
+          checkAnswer(
+            Seq(0.toShort, 1.toShort).toDS
+              .select($"value".isInCollection(Seq(0.toShort, 2.toShort))),
+            expected)
+          checkAnswer(Seq(0, 1).toDS.select($"value".isInCollection(Seq(0, 2))), expected)
+          checkAnswer(Seq(0L, 1L).toDS.select($"value".isInCollection(Seq(0L, 2L))), expected)
+          checkAnswer(Seq(0.0f, 1.0f).toDS
+            .select($"value".isInCollection(Seq(0.0f, 2.0f))), expected)
+          checkAnswer(Seq(0.0D, 1.0D).toDS
+            .select($"value".isInCollection(Seq(0.0D, 2.0D))), expected)
+          checkAnswer(
+            Seq(BigDecimal(0), BigDecimal(2)).toDS
+              .select($"value".isInCollection(Seq(BigDecimal(0), BigDecimal(1)))),
+            expected)
+          checkAnswer(
+            Seq("abc", "def").toDS.select($"value".isInCollection(Seq("abc", "xyz"))),
+            expected)
+          checkAnswer(
+            Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-05-01")).toDS
+              .select($"value".isInCollection(
+                Seq(Date.valueOf("2020-04-29"), Date.valueOf("2020-04-30")))),
+            expected)
+          checkAnswer(
+            Seq(new Timestamp(0), new Timestamp(2)).toDS
+              .select($"value".isInCollection(Seq(new Timestamp(0), new Timestamp(1)))),
+            expected)
+          checkAnswer(
+            Seq(Array("a", "b"), Array("c", "d")).toDS
+              .select($"value".isInCollection(Seq(Array("a", "b"), Array("x", "z")))),
+            expected)
+        }
       }
+    }
   }
 
   test("&&") {
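
The new tests sweep `spark.sql.optimizer.inSetConversionThreshold` and `spark.sql.optimizer.inSetSwitchThreshold`, which, per their `SQLConf` descriptions, control when the optimizer rewrites an `In` predicate to `InSet` and up to what set size `InSet` code-generates a switch statement. For readers reproducing the test matrix by hand, a minimal sketch against a live session (assumes a `SparkSession` named `spark`, e.g. in spark-shell, with `spark.implicits._` in scope; verify the form the optimizer actually chose with `explain()`):

    import org.apache.spark.sql.internal.SQLConf

    // With the conversion threshold at 1, a two-element value list is
    // rewritten from In to InSet; a switch threshold of 10 then lets
    // InSet generate a switch statement for sets up to that size.
    spark.conf.set(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key, "1")
    spark.conf.set(SQLConf.OPTIMIZER_INSET_SWITCH_THRESHOLD.key, "10")
    Seq(1, 2, 3).toDF("a").filter($"a".isInCollection(Seq(1, 3))).explain()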


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org