You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by yh...@apache.org on 2016/04/19 19:52:03 UTC

spark git commit: [SPARK-14675][SQL] ClassFormatError when use Seq as Aggregator buffer type

Repository: spark
Updated Branches:
  refs/heads/master 947b9020b -> 5cb2e3360


[SPARK-14675][SQL] ClassFormatError when use Seq as Aggregator buffer type

## What changes were proposed in this pull request?

After https://github.com/apache/spark/pull/12067, we now use expressions to do the aggregation in `TypedAggregateExpression`. To implement buffer merge, we produce a new buffer deserializer expression by replacing each `AttributeReference` with the right-side buffer attribute, like other `DeclarativeAggregate`s do, and finally combine the left and right buffer deserializers with `Invoke`.

However, after https://github.com/apache/spark/pull/12338, we add the loop variable to the class members when generating code for `MapObjects`. If the `Aggregator` buffer type is `Seq`, which is implemented by the `MapObjects` expression, we add the same loop variable to the class members twice (once by the left and once by the right buffer deserializer), which causes the `ClassFormatError`.

This PR fixes this issue by calling `distinct` before declaring the class members.

## How was this patch tested?

new regression test in `DatasetAggregatorSuite`

Author: Wenchen Fan <we...@databricks.com>

Closes #12468 from cloud-fan/bug.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5cb2e336
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5cb2e336
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5cb2e336

Branch: refs/heads/master
Commit: 5cb2e3360985bc9e67aee038befa93c258f2016a
Parents: 947b902
Author: Wenchen Fan <we...@databricks.com>
Authored: Tue Apr 19 10:51:58 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Tue Apr 19 10:51:58 2016 -0700

----------------------------------------------------------------------
 .../expressions/EquivalentExpressions.scala     | 13 +++++++++++--
 .../expressions/codegen/CodeGenerator.scala     |  8 ++++++--
 .../spark/sql/DatasetAggregatorSuite.scala      | 20 ++++++++++++++++++++
 3 files changed, 37 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5cb2e336/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 8d8cc15..607c7c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -69,8 +69,17 @@ class EquivalentExpressions {
    */
   def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
     val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
-    // the children of CodegenFallback will not be used to generate code (call eval() instead)
-    if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
+    // There are some special expressions that we should not recurse into children.
+    //   1. CodegenFallback: it's children will not be used to generate code (call eval() instead)
+    //   2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination.
+    val shouldRecurse = root match {
+      // TODO: some expressions implements `CodegenFallback` but can still do codegen,
+      // e.g. `CaseWhen`, we should support them.
+      case _: CodegenFallback => false
+      case _: ReferenceToExpressions => false
+      case _ => true
+    }
+    if (!skip && !addExpr(root) && shouldRecurse) {
       root.children.foreach(addExprTree(_, ignoreLeaf))
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/5cb2e336/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 38ac13b..d29c27c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -110,13 +110,17 @@ class CodegenContext {
   }
 
   def declareMutableStates(): String = {
-    mutableStates.map { case (javaType, variableName, _) =>
+    // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
+    // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
+    mutableStates.distinct.map { case (javaType, variableName, _) =>
       s"private $javaType $variableName;"
     }.mkString("\n")
   }
 
   def initMutableStates(): String = {
-    mutableStates.map(_._3).mkString("\n")
+    // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in
+    // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones.
+    mutableStates.distinct.map(_._3).mkString("\n")
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/5cb2e336/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 0d84a59..6eae3ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.postfixOps
 
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.expressions.scala.typed
 import org.apache.spark.sql.functions._
@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] {
 }
 
 
+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+  def zero: Seq[Int] = Nil
+  def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
+  def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
+  def finish(r: Seq[Int]): Seq[Int] = r
+  override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+  override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+}
+
+
 class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
   extends Aggregator[IN, OUT, OUT] {
 
@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
     val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
     checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
   }
+
+  test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") {
+    val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
+
+    checkDataset(
+      ds.groupByKey(_.b).agg(SeqAgg.toColumn),
+      "a" -> Seq(1, 2)
+    )
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org