Posted to reviews@spark.apache.org by mgaido91 <gi...@git.apache.org> on 2018/10/01 07:59:34 UTC

[GitHub] spark pull request #22524: [SPARK-25497][SQL] Limit operation within whole s...

Github user mgaido91 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22524#discussion_r221520640
  
    --- Diff: sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---
    @@ -2850,6 +2849,80 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
         result.rdd.isEmpty
       }
     
    +  test("SPARK-25497: limit operation within whole stage codegen should not " +
    +    "consume all the inputs") {
    +
    +    val aggDF = spark.range(0, 100, 1, 1)
    +      .groupBy("id")
    +      .count().limit(1).filter('count > 0)
    +    aggDF.collect()
    +    val aggNumRecords = aggDF.queryExecution.sparkPlan.collect {
    +      case h: HashAggregateExec => h
    +    }.map { hashNode =>
    +      hashNode.metrics("numOutputRows").value
    +    }.sum
    +    // The first hash aggregate node outputs 100 records.
    +    // The second hash aggregate before local limit outputs 1 record.
    +    assert(aggNumRecords == 101)
    +
    +    val aggNoGroupingDF = spark.range(0, 100, 1, 1)
    +      .groupBy()
    +      .count().limit(1).filter('count > 0)
    +    aggNoGroupingDF.collect()
    +    val aggNoGroupingNumRecords = aggNoGroupingDF.queryExecution.sparkPlan.collect {
    +      case h: HashAggregateExec => h
    +    }.map { hashNode =>
    +      hashNode.metrics("numOutputRows").value
    +    }.sum
    +    assert(aggNoGroupingNumRecords == 2)
    +
    +    // Set `TOP_K_SORT_FALLBACK_THRESHOLD` to a low value because we don't want the sort + limit
    +    // to be planned as a `TakeOrderedAndProject` node.
    +    withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1") {
    +      val sortDF = spark.range(0, 100, 1, 1)
    +        .filter('id >= 0)
    +        .limit(10)
    +        .sortWithinPartitions("id")
    +        // Use a non-deterministic expression to prevent the filter from being pushed down.
    +        .selectExpr("rand() + id as id2")
    +        .filter('id2 >= 0)
    +        .limit(5)
    +        .selectExpr("1 + id2 as id3")
    +      sortDF.collect()
    +      val sortNumRecords = sortDF.queryExecution.sparkPlan.collect {
    +        case l@LocalLimitExec(_, f: FilterExec) => f
    +      }.map { filterNode =>
    +        filterNode.metrics("numOutputRows").value
    +      }
    +      assert(sortNumRecords.sorted === Seq(5, 10))
    +    }
    +
    +    val filterDF = spark.range(0, 100, 1, 1).filter('id >= 0)
    +      .selectExpr("id + 1 as id2").limit(1).filter('id > 50)
    +    filterDF.collect()
    +    val filterNumRecords = filterDF.queryExecution.sparkPlan.collect {
    +      case f@FilterExec(_, r: RangeExec) => f
    +    }.map { case filterNode =>
    +      filterNode.metrics("numOutputRows").value
    +    }.head
    +    assert(filterNumRecords == 1)
    +
    +    val twoLimitsDF = spark.range(0, 100, 1, 1)
    +      .filter('id >= 0)
    +      .limit(1)
    +      .selectExpr("id + 1 as id2")
    +      .limit(2)
    +      .filter('id2 >= 0)
    +    twoLimitsDF.collect()
    +    val twoLimitsDFNumRecords = twoLimitsDF.queryExecution.sparkPlan.collect {
    +      case f@FilterExec(_, _: RangeExec) => f
    --- End diff --
    
    nit: spaces around the `@` pattern binder
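    
    i.e., assuming the usual Scala style convention of spaces around `@`, something like:
    
        case f @ FilterExec(_, _: RangeExec) => f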


---
