You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/16 15:22:51 UTC

[spark] branch master updated: [SPARK-37961][SQL] Override maxRows/maxRowsPerPartition for some logical operators

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e841fa3f769 [SPARK-37961][SQL] Override maxRows/maxRowsPerPartition for some logical operators
e841fa3f769 is described below

commit e841fa3f76988b532bb609c6e9de047703ff2783
Author: Ruifeng Zheng <ru...@apache.org>
AuthorDate: Thu Jun 16 23:22:35 2022 +0800

    [SPARK-37961][SQL] Override maxRows/maxRowsPerPartition for some logical operators
    
    ### What changes were proposed in this pull request?
    
    1. Override `maxRowsPerPartition` in `Sort`, `Expand`, `Sample`, and `CollectMetrics`;
    2. Override `maxRows` in `Except`, `Expand`, and `CollectMetrics`.
    
    ### Why are the changes needed?
    
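    To report an accurate upper bound on the number of output rows (in total and per partition) wherever one can be derived from the child plans.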
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added a test case to `LogicalPlanSuite`.
    
    Closes #35250 from zhengruifeng/add_some_maxRows_maxRowsPerPartition.
    
    Authored-by: Ruifeng Zheng <ru...@apache.org>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../plans/logical/basicLogicalOperators.scala      | 28 ++++++++++++++++++++--
 .../sql/catalyst/plans/LogicalPlanSuite.scala      | 26 ++++++++++++++++++++
 2 files changed, 52 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 11d68294023..32045ff5a52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -251,6 +251,8 @@ case class Except(
 
   override protected lazy val validConstraints: ExpressionSet = leftConstraints
 
+  override def maxRows: Option[Long] = left.maxRows
+
   override protected def withNewChildrenInternal(
     newLeft: LogicalPlan, newRight: LogicalPlan): Except = copy(left = newLeft, right = newRight)
 }
@@ -758,6 +760,9 @@ case class Sort(
     child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = {
+    if (global) maxRows else child.maxRowsPerPartition
+  }
   override def outputOrdering: Seq[SortOrder] = order
   final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
@@ -1163,6 +1168,19 @@ case class Expand(
   override lazy val references: AttributeSet =
     AttributeSet(projections.flatten.flatMap(_.references))
 
+  override def maxRows: Option[Long] = child.maxRows match {
+    case Some(m) =>
+      val n = BigInt(m) * projections.length
+      if (n.isValidLong) Some(n.toLong) else None
+    case _ => None
+  }
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition match {
+    case Some(m) =>
+      val n = BigInt(m) * projections.length
+      if (n.isValidLong) Some(n.toLong) else None
+    case _ => maxRows
+  }
+
   override def metadataOutput: Seq[Attribute] = Nil
 
   override def producedAttributes: AttributeSet = AttributeSet(output diff child.output)
@@ -1432,11 +1450,15 @@ case class Sample(
       s"Sampling fraction ($fraction) must be on interval [0, 1] without replacement")
   }
 
+  // when withReplacement is true, PoissonSampler is applied in SampleExec,
+  // which may output more rows than child.
   override def maxRows: Option[Long] = {
-    // when withReplacement is true, PoissonSampler is applied in SampleExec,
-    // which may output more rows than child.maxRows.
     if (withReplacement) None else child.maxRows
   }
+  override def maxRowsPerPartition: Option[Long] = {
+    if (withReplacement) None else child.maxRowsPerPartition
+  }
+
   override def output: Seq[Attribute] = child.output
 
   override protected def withNewChildInternal(newChild: LogicalPlan): Sample =
@@ -1626,6 +1648,8 @@ case class CollectMetrics(
     name.nonEmpty && metrics.nonEmpty && metrics.forall(_.resolved) && childrenResolved
   }
 
+  override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
   override def output: Seq[Attribute] = child.output
 
   override protected def withNewChildInternal(newChild: LogicalPlan): CollectMetrics =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 5dac35a33a6..1d533e9d0d4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -113,4 +113,30 @@ class LogicalPlanSuite extends SparkFunSuite {
     assert(query.maxRows.isEmpty)
     assert(query.maxRowsPerPartition.isEmpty)
   }
+
+  test("SPARK-37961: add maxRows/maxRowsPerPartition for some logical nodes") {
+    val range = Range(0, 100, 1, 3)
+    assert(range.maxRows === Some(100))
+    assert(range.maxRowsPerPartition === Some(34))
+
+    val sort = Sort(Seq('id.asc), false, range)
+    assert(sort.maxRows === Some(100))
+    assert(sort.maxRowsPerPartition === Some(34))
+    val sort2 = Sort(Seq('id.asc), true, range)
+    assert(sort2.maxRows === Some(100))
+    assert(sort2.maxRowsPerPartition === Some(100))
+
+    val c1 = Literal(1).as('a).toAttribute.newInstance().withNullability(true)
+    val c2 = Literal(2).as('b).toAttribute.newInstance().withNullability(true)
+    val expand = Expand(
+      Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
+      Seq(c1, c2),
+      sort.select('id as 'a, 'id + 1 as 'b))
+    assert(expand.maxRows === Some(200))
+    assert(expand.maxRowsPerPartition === Some(68))
+
+    val sample = Sample(0.1, 0.9, false, 42, expand)
+    assert(sample.maxRows === Some(200))
+    assert(sample.maxRowsPerPartition === Some(68))
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org