You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2022/07/15 03:45:11 UTC

[spark] branch master updated: [SPARK-39748][SQL][SS][FOLLOWUP] Fix a bug on column stat in LogicalRDD on mismatching exprIDs

This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c80ed8547d [SPARK-39748][SQL][SS][FOLLOWUP] Fix a bug on column stat in LogicalRDD on mismatching exprIDs
3c80ed8547d is described below

commit 3c80ed8547d550a056788267bbb395fde8b2c030
Author: Jungtaek Lim <ka...@gmail.com>
AuthorDate: Fri Jul 15 12:44:53 2022 +0900

    [SPARK-39748][SQL][SS][FOLLOWUP] Fix a bug on column stat in LogicalRDD on mismatching exprIDs
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug introduced in #37161 (described in the section below) by making sure the output columns of LogicalRDD are always the same as the output columns of its originLogicalPlan, which is needed to inherit the column stats.
    
    ### Why are the changes needed?
    
    Stats for columns in originLogicalPlan refer to the columns in originLogicalPlan, which could be different from the columns in output of LogicalRDD in terms of expression ID.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New UT
    
    Closes #37187 from HeartSaVioR/SPARK-39748-FOLLOWUP-2.
    
    Authored-by: Jungtaek Lim <ka...@gmail.com>
    Signed-off-by: Jungtaek Lim <ka...@gmail.com>
---
 .../apache/spark/sql/execution/ExistingRDD.scala   |  12 ++-
 .../streaming/sources/ForeachBatchSink.scala       |   7 +-
 .../org/apache/spark/sql/DataFrameSuite.scala      | 113 ++++++++++++++++++++-
 3 files changed, 126 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index bf9ef6991e3..149e70e56d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -116,10 +116,20 @@ case class LogicalRDD(
       case e: Attribute => rewrite.getOrElse(e, e)
     }.asInstanceOf[SortOrder])
 
+    val rewrittenOriginLogicalPlan = originLogicalPlan.map { plan =>
+      assert(output == plan.output, "The output columns are expected to the same for output " +
+        s"and originLogicalPlan. output: $output / output in originLogicalPlan: ${plan.output}")
+
+      val projectList = output.map { attr =>
+        Alias(attr, attr.name)(exprId = rewrite(attr).exprId)
+      }
+      Project(projectList, plan)
+    }
+
     LogicalRDD(
       output.map(rewrite),
       rdd,
-      originLogicalPlan,
+      rewrittenOriginLogicalPlan,
       rewrittenPartitioning,
       rewrittenOrdering,
       isStreaming
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
index 1c6bca241af..395ed056be2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachBatchSink.scala
@@ -30,10 +30,13 @@ class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: Expr
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
     val rdd = data.queryExecution.toRdd
     val executedPlan = data.queryExecution.executedPlan
+    val analyzedPlanWithoutMarkerNode = eliminateWriteMarkerNode(data.queryExecution.analyzed)
+    // assertion on precondition
+    assert(data.logicalPlan.output == analyzedPlanWithoutMarkerNode.output)
     val node = LogicalRDD(
-      data.schema.toAttributes,
+      data.logicalPlan.output,
       rdd,
-      Some(eliminateWriteMarkerNode(data.queryExecution.analyzed)),
+      Some(analyzedPlanWithoutMarkerNode),
       executedPlan.outputPartitioning,
       executedPlan.outputOrdering)(data.sparkSession)
     implicit val enc = encoder
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 41593c701a7..e802159f263 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -32,13 +32,14 @@ import org.scalatest.matchers.should.Matchers._
 import org.apache.spark.SparkException
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
-import org.apache.spark.sql.catalyst.expressions.Uuid
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, OneRowRelation}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.connector.FakeV2Provider
-import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{FilterExec, LogicalRDD, QueryExecution, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
@@ -2010,6 +2011,68 @@ class DataFrameSuite extends QueryTest
     }
   }
 
+  test("SPARK-39748: build the stats for LogicalRDD based on originLogicalPlan") {
+    def buildExpectedColumnStats(attrs: Seq[Attribute]): AttributeMap[ColumnStat] = {
+      AttributeMap(
+        attrs.map {
+          case attr if attr.dataType == BooleanType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(false),
+              max = Some(true),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr if attr.dataType == ByteType =>
+            attr -> ColumnStat(
+              distinctCount = Some(2),
+              min = Some(1),
+              max = Some(2),
+              nullCount = Some(0),
+              avgLen = Some(1),
+              maxLen = Some(1))
+
+          case attr => attr -> ColumnStat()
+        }
+      )
+    }
+
+    val outputList = Seq(
+      AttributeReference("cbool", BooleanType)(),
+      AttributeReference("cbyte", ByteType)()
+    )
+
+    val expectedSize = 16
+    val statsPlan = OutputListAwareStatsTestPlan(
+      outputList = outputList,
+      rowCount = 2,
+      size = Some(expectedSize))
+
+    withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
+      val df = Dataset.ofRows(spark, statsPlan)
+
+      val logicalRDD = LogicalRDD(
+        df.logicalPlan.output, spark.sparkContext.emptyRDD, Some(df.queryExecution.analyzed),
+        isStreaming = true)(spark)
+
+      val stats = logicalRDD.computeStats()
+      val expectedStats = Statistics(sizeInBytes = expectedSize, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(logicalRDD.output))
+      assert(stats === expectedStats)
+
+      // newInstance() re-issues expression IDs for all outputs. The column stats are expected
+      // to be keyed by the re-issued attributes as well.
+      val newLogicalRDD = logicalRDD.newInstance()
+      val newStats = newLogicalRDD.computeStats()
+      // LogicalRDD.newInstance adds projection to originLogicalPlan, which triggers estimation
+      // on sizeInBytes. We don't intend to check the estimated value.
+      val newExpectedStats = Statistics(sizeInBytes = newStats.sizeInBytes, rowCount = Some(2),
+        attributeStats = buildExpectedColumnStats(newLogicalRDD.output))
+      assert(newStats === newExpectedStats)
+    }
+  }
+
   test("SPARK-10656: completely support special chars") {
     val df = Seq(1 -> "a").toDF("i_$.a", "d^'a.")
     checkAnswer(df.select(df("*")), Row(1, "a"))
@@ -3249,3 +3312,47 @@ class DataFrameSuite extends QueryTest
 case class GroupByKey(a: Int, b: Int)
 
 case class Bar2(s: String)
+
+/**
+ * Test-only leaf plan: output, row count and size are passed in; column stats follow data type.
+ */
+case class OutputListAwareStatsTestPlan(
+    outputList: Seq[Attribute],
+    rowCount: BigInt,
+    size: Option[BigInt] = None) extends LeafNode with MultiInstanceRelation {
+  override def output: Seq[Attribute] = outputList
+  override def computeStats(): Statistics = {
+    val columnInfo = outputList.map { attr =>
+      attr.dataType match {
+        case BooleanType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(false),
+            max = Some(true),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case ByteType =>
+          attr -> ColumnStat(
+            distinctCount = Some(2),
+            min = Some(1),
+            max = Some(2),
+            nullCount = Some(0),
+            avgLen = Some(1),
+            maxLen = Some(1))
+
+        case _ =>
+          attr -> ColumnStat()
+      }
+    }
+    val attrStats = AttributeMap(columnInfo)
+
+    Statistics(
+      // When the caller does not supply a size, fall back to a placeholder (Int.MaxValue)
+      sizeInBytes = size.getOrElse(Int.MaxValue),
+      rowCount = Some(rowCount),
+      attributeStats = attrStats)
+  }
+  override def newInstance(): LogicalPlan = copy(outputList = outputList.map(_.newInstance()))
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org