Posted to commits@spark.apache.org by we...@apache.org on 2020/04/28 14:12:35 UTC

[spark] branch branch-3.0 updated: [SPARK-31553][SQL] Revert "[SPARK-29048] Improve performance on Column.isInCollection() with a large size collection"

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new efe700c  [SPARK-31553][SQL] Revert "[SPARK-29048] Improve performance on Column.isInCollection() with a large size collection"
efe700c is described below

commit efe700c90235950df9ced1cb5512770b0e10b9c1
Author: Max Gekk <ma...@gmail.com>
AuthorDate: Tue Apr 28 14:10:50 2020 +0000

    [SPARK-31553][SQL] Revert "[SPARK-29048] Improve performance on Column.isInCollection() with a large size collection"
    
    ### What changes were proposed in this pull request?
    This reverts commit 5631a96367d2576e1e0f95d7ae529468da8f5fa8.
    
    Closes #28328
    
    ### Why are the changes needed?
    The PR https://github.com/apache/spark/pull/25754 introduced a bug in `isInCollection`. For example, if the SQL config `spark.sql.optimizer.inSetConversionThreshold` is set to 10 (the default):
    ```scala
    val set = (0 to 20).map(_.toString).toSet
    val data = Seq("1").toDF("x")
    data.select($"x".isInCollection(set).as("isInCollection")).show()
    ```
    The function must return **true** because "1" is in the set of "0" ... "20", but it actually returns "false":
    ```
    +--------------+
    |isInCollection|
    +--------------+
    |         false|
    +--------------+
    ```
    
    ### Does this PR introduce any user-facing change?
    Yes. `Column.isInCollection()` returns correct results again for collections larger than `spark.sql.optimizer.inSetConversionThreshold`.
    
    ### How was this patch tested?
    ```
    $ ./build/sbt "test:testOnly *ColumnExpressionSuite"
    ```
    
    Closes #28388 from MaxGekk/fix-isInCollection-revert.
    
    Authored-by: Max Gekk <ma...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit b7cabc80e6df523f0377b651fdbdc2a669c11550)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Column.scala   | 10 +----
 .../apache/spark/sql/ColumnExpressionSuite.scala   | 45 ++++++++--------------
 2 files changed, 18 insertions(+), 37 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 50bc7ec..6913d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -31,7 +31,6 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 private[sql] object Column {
@@ -827,14 +826,7 @@ class Column(val expr: Expression) extends Logging {
    * @group expr_ops
    * @since 2.4.0
    */
-  def isInCollection(values: scala.collection.Iterable[_]): Column = withExpr {
-    val hSet = values.toSet[Any]
-    if (hSet.size > SQLConf.get.optimizerInSetConversionThreshold) {
-      InSet(expr, hSet)
-    } else {
-      In(expr, values.toSeq.map(lit(_).expr))
-    }
-  }
+  def isInCollection(values: scala.collection.Iterable[_]): Column = isin(values.toSeq: _*)
 
   /**
    * A boolean expression that is evaluated to true if the value of this expression is contained
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index b72d92b..8d3b562 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -26,7 +26,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.Matchers._
 
-import org.apache.spark.sql.catalyst.expressions.{In, InSet, Literal, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -455,36 +455,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
   test("isInCollection: Scala Collection") {
     val df = Seq((1, "x"), (2, "y"), (3, "z")).toDF("a", "b")
+    // Test with different types of collections
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
+    checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
+    checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
+      df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
 
-    Seq(1, 2).foreach { conf =>
-      withSQLConf(SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> conf.toString) {
-        if (conf <= 1) {
-          assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[InSet], "Expect expr to be InSet")
-        } else {
-          assert($"a".isInCollection(Seq(3, 1)).expr.isInstanceOf[In], "Expect expr to be In")
-        }
-
-        // Test with different types of collections
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 1))),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
-        checkAnswer(df.filter($"a".isInCollection(Seq(1, 2).toSet)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 1 || r.getInt(0) == 2))
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 2).toArray)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 2))
-        checkAnswer(df.filter($"a".isInCollection(Seq(3, 1).toList)),
-          df.collect().toSeq.filter(r => r.getInt(0) == 3 || r.getInt(0) == 1))
-
-        val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
+    val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
 
-        val e = intercept[AnalysisException] {
-          df2.filter($"a".isInCollection(Seq($"b")))
-        }
-        Seq("cannot resolve",
-          "due to data type mismatch: Arguments must be same type but were").foreach { s =>
-            assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
-          }
-      }
+    val e = intercept[AnalysisException] {
+      df2.filter($"a".isInCollection(Seq($"b")))
     }
+    Seq("cannot resolve", "due to data type mismatch: Arguments must be same type but were")
+      .foreach { s =>
+        assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT)))
+      }
   }
 
   test("&&") {

