Posted to commits@spark.apache.org by we...@apache.org on 2019/02/20 02:25:03 UTC

[spark] branch master updated: [SPARK-26901][SQL][R] Adds child's output into references to avoid column-pruning for vectorized gapply()

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ab850c0  [SPARK-26901][SQL][R] Adds child's output into references to avoid column-pruning for vectorized gapply()
ab850c0 is described below

commit ab850c02f7f2a2fd7edea99c1279d23a01aeab34
Author: Hyukjin Kwon <gu...@apache.org>
AuthorDate: Wed Feb 20 10:24:40 2019 +0800

    [SPARK-26901][SQL][R] Adds child's output into references to avoid column-pruning for vectorized gapply()
    
    ## What changes were proposed in this pull request?
    
    Currently, column pruning is applied to vectorized `gapply()`. Since the R native function may use any of the referred fields, the child's output should not be pruned. To avoid this, this PR adds the child's output into `references`, like `OutputConsumer`.
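    
    The pattern itself is a standard Catalyst one: a unary node can declare that it references every attribute its child produces, which keeps the optimizer's `ColumnPruning` rule from inserting a narrowing `Project` below it. Below is a minimal sketch of that pattern (the operator name is made up for illustration; the actual two-line change is in the diff at the bottom):
    
    ```scala
    import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
    
    // Hypothetical operator whose user function may read any input column,
    // so the optimizer must not prune the child's output.
    case class OpaqueRowConsumer(child: LogicalPlan) extends UnaryNode {
      override def output: Seq[Attribute] = child.output
    
      // Declaring every child column as referenced disables column pruning
      // below this node.
      override def references: AttributeSet = child.outputSet
    }
    ```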
    
    ```
    $ ./bin/sparkR --conf spark.sql.execution.arrow.enabled=true
    ```
    
    ```r
    df <- createDataFrame(mtcars)
    explain(count(groupBy(gapply(df,
                                 "gear",
                                 function(key, group) {
                                   data.frame(gear = key[[1]], disp = mean(group$disp))
                                 },
                                 structType("gear double, disp double")))), TRUE)
    ```
    
    **Before:**
    
    ```
    == Optimized Logical Plan ==
    Aggregate [count(1) AS count#41L]
    +- Project
       +- FlatMapGroupsInRWithArrow [...]
          +- Project [gear#9]
             +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false
    
    == Physical Plan ==
    *(4) HashAggregate(keys=[], functions=[count(1)], output=[count#41L])
    +- Exchange SinglePartition
       +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#44L])
          +- *(3) Project
             +- FlatMapGroupsInRWithArrow [...]
                +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
                   +- Exchange hashpartitioning(gear#9, 200)
                      +- *(1) Project [gear#9]
                         +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
    ```
    
    **After:**
    
    ```
    == Optimized Logical Plan ==
    Aggregate [count(1) AS count#91L]
    +- Project
       +- FlatMapGroupsInRWithArrow [...]
          +- LogicalRDD [mpg#0, cyl#1, disp#2, hp#3, drat#4, wt#5, qsec#6, vs#7, am#8, gear#9, carb#10], false
    
    == Physical Plan ==
    *(4) HashAggregate(keys=[], functions=[count(1)], output=[count#91L])
    +- Exchange SinglePartition
       +- *(3) HashAggregate(keys=[], functions=[partial_count(1)], output=[count#94L])
          +- *(3) Project
             +- FlatMapGroupsInRWithArrow [...]
                +- *(2) Sort [gear#9 ASC NULLS FIRST], false, 0
                   +- Exchange hashpartitioning(gear#9, 200)
                      +- *(1) Scan ExistingRDD arrow[mpg#0,cyl#1,disp#2,hp#3,drat#4,wt#5,qsec#6,vs#7,am#8,gear#9,carb#10]
    ```
    
    Currently, pruned columnar batches are passed to Arrow writers that require the non-pruned columns, so the missing columns are filled with corrupt values such as the following (a toy sketch of the mismatch follows the listings below):
    
    ```r
    ...
      c(7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 7.90505033345994e-323, 0, 0, 4.17777978645388e-314)
      c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1.04669129845114e+219)
      c(3.4482690635875e-313, 3.4482690635875e-313, 3.4482690635875e-313,
      c(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2.47032822920623e-323)
    ...
    ```
    
    which should be something like:
    
    ```r
    ...
      c(4, 4, 1, 2, 2, 4, 4, 1, 2, 1, 1, 2)
      c(26, 30.4, 15.8, 19.7, 15)
      c(4, 4, 8, 6, 8)
      c(120.3, 95.1, 351, 145, 301)
    ...
    ```
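    
    The mismatch can be pictured with a toy, self-contained sketch (plain Scala, not Spark's actual Arrow writer): the consumer still walks the full 11-column schema while the pruned batch only carries `gear`, so every other column is backed by nothing meaningful:
    
    ```scala
    // Toy sketch only: it models the shape of the bug, not Spark's Arrow writer.
    object PrunedBatchMismatch extends App {
      val fullSchema = Seq("mpg", "cyl", "disp", "hp", "drat", "wt",
                           "qsec", "vs", "am", "gear", "carb")
    
      // After column pruning, the batch carries only the grouping column ...
      val prunedBatch = Map("gear" -> Array(3.0, 4.0, 4.0))
    
      // ... but a consumer keyed to the full schema still asks for every
      // column; missing ones come back as meaningless stand-ins (zeros here).
      fullSchema.foreach { col =>
        val values = prunedBatch.getOrElse(col, Array.fill(3)(0.0))
        println(f"$col%-4s -> ${values.mkString(", ")}")
      }
    }
    ```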
    
    ## How was this patch tested?
    
    Manually tested, and unit tests were added.
    
    The test code is basically:
    
    ```r
    df <- createDataFrame(mtcars)
    count(gapply(df,
                 c("gear"),
                 function(key, group) {
                    stopifnot(all(group$hp > 50))
                    group
                 },
                 schema(df)))
    ```
    
    Every `hp` value in `mtcars` is greater than 50:
    
    ```r
    > mtcars$hp > 50
     [1] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
    [16] TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE TRUE
    [31] TRUE TRUE
    ```
    
    However, due to the corrupt values (like 0 or 7.xxxxx), weird values were found, so this currently fails in the master branch as below:
    
    ```
    Error in handleErrors(returnStatus, conn) :
      org.apache.spark.SparkException: Job aborted due to stage failure: Task 82 in stage 1.0 failed 1 times, most recent failure: Lost task 82.0 in stage 1.0 (TID 198, localhost, executor driver): org.apache.spark.SparkException: R worker exited unexpectedly (crashed)
     Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
    Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
    Error in computeFunc(key, inputData) : all(group$hp > 50) is not TRUE
    ```
    
    I also compared the total length while I was here. Regular `gapply()` without Arrow has some holes, so I had to compare the results against the R data frame.
    
    Closes #23810 from HyukjinKwon/SPARK-26901.
    
    Lead-authored-by: Hyukjin Kwon <gu...@apache.org>
    Co-authored-by: Hyukjin Kwon <gu...@gmail.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 R/pkg/tests/fulltests/test_sparkSQL.R                                 | 4 ++++
 .../scala/org/apache/spark/sql/catalyst/plans/logical/object.scala    | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index 9dc699c..4d1360b 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -3564,11 +3564,15 @@ test_that("gapply() Arrow optimization", {
                      stopifnot(is.numeric(key[[1]]))
                    }
                    stopifnot(class(grouped) == "data.frame")
+                   stopifnot(length(colnames(grouped)) == 11)
+                   # mtcars' hp is more than 50.
+                   stopifnot(all(grouped$hp > 50))
                    grouped
                  },
                  schema(df))
     actual <- collect(ret)
     expect_equal(actual, expected)
+    expect_equal(count(ret), nrow(mtcars))
   },
   finally = {
     # Resetting the conf back to default value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 58bb191..f875af3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -499,6 +499,8 @@ case class FlatMapGroupsInRWithArrow(
     keyDeserializer: Expression,
     groupingAttributes: Seq[Attribute],
     child: LogicalPlan) extends UnaryNode {
+  // This operator always needs all columns of its child, even those it does not reference.
+  override def references: AttributeSet = child.outputSet
 
   override protected def stringArgs: Iterator[Any] = Iterator(
     inputSchema, StructType.fromAttributes(output), keyDeserializer, groupingAttributes, child)

