You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ma...@apache.org on 2014/07/05 20:51:59 UTC

git commit: [SPARK-2327] [SQL] Fix nullabilities of Join/Generate/Aggregate.

Repository: spark
Updated Branches:
  refs/heads/master 3da8df939 -> 9d5ecf820


[SPARK-2327] [SQL] Fix nullabilities of Join/Generate/Aggregate.

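Fix nullabilities of `Join`/`Generate`/`Aggregate` because:
- Output attributes from the non-preserved side of an outer join (the right side of `LeftOuter`, the left side of `RightOuter`, and both sides of `FullOuter`) should be nullable.
- Output attributes from the generator side of `Generate` should be nullable when both `join` and `outer` are `true`.
- The nullability of each result `AttributeReference` in `computedAggregates` of `Aggregate` should match the nullability of its aggregate expression, instead of always being `true`.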

Author: Takuya UESHIN <ue...@happy-camper.st>

Closes #1266 from ueshin/issues/SPARK-2327 and squashes the following commits:

3ace83a [Takuya UESHIN] Add withNullability to Attribute and use it to change nullabilities.
df1ae53 [Takuya UESHIN] Modify nullabilize to leave attribute if not resolved.
799ce56 [Takuya UESHIN] Add nullabilization to Generate of SparkPlan.
a0fc9bc [Takuya UESHIN] Fix scalastyle errors.
0e31e37 [Takuya UESHIN] Fix Aggregate resultAttribute nullabilities.
09532ec [Takuya UESHIN] Fix Generate output nullabilities.
f20f196 [Takuya UESHIN] Fix Join output nullabilities.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d5ecf82
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d5ecf82
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d5ecf82

Branch: refs/heads/master
Commit: 9d5ecf8205b924dc8a3c13fed68beb78cc5c7553
Parents: 3da8df9
Author: Takuya UESHIN <ue...@happy-camper.st>
Authored: Sat Jul 5 11:51:48 2014 -0700
Committer: Michael Armbrust <mi...@databricks.com>
Committed: Sat Jul 5 11:51:48 2014 -0700

----------------------------------------------------------------------
 .../sql/catalyst/analysis/unresolved.scala      |  2 ++
 .../catalyst/expressions/BoundAttribute.scala   | 16 +++++-----
 .../catalyst/expressions/namedExpressions.scala |  3 +-
 .../catalyst/plans/logical/basicOperators.scala | 31 +++++++++++++++-----
 .../apache/spark/sql/execution/Aggregate.scala  |  4 +--
 .../apache/spark/sql/execution/Generate.scala   | 12 ++++++--
 .../org/apache/spark/sql/execution/joins.scala  | 13 +++++++-
 7 files changed, 60 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index d629172..7abeb03 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -52,6 +52,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo
   override lazy val resolved = false
 
   override def newInstance = this
+  override def withNullability(newNullability: Boolean) = this
   override def withQualifiers(newQualifiers: Seq[String]) = this
 
   // Unresolved attributes are transient at compile time and don't get evaluated during execution.
@@ -95,6 +96,7 @@ case class Star(
   override lazy val resolved = false
 
   override def newInstance = this
+  override def withNullability(newNullability: Boolean) = this
   override def withQualifiers(newQualifiers: Seq[String]) = this
 
   def expand(input: Seq[Attribute]): Seq[NamedExpression] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index 655d4a0..9ce1f01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -33,14 +33,16 @@ case class BoundReference(ordinal: Int, baseReference: Attribute)
 
   type EvaluatedType = Any
 
-  def nullable = baseReference.nullable
-  def dataType = baseReference.dataType
-  def exprId = baseReference.exprId
-  def qualifiers = baseReference.qualifiers
-  def name = baseReference.name
+  override def nullable = baseReference.nullable
+  override def dataType = baseReference.dataType
+  override def exprId = baseReference.exprId
+  override def qualifiers = baseReference.qualifiers
+  override def name = baseReference.name
 
-  def newInstance = BoundReference(ordinal, baseReference.newInstance)
-  def withQualifiers(newQualifiers: Seq[String]) =
+  override def newInstance = BoundReference(ordinal, baseReference.newInstance)
+  override def withNullability(newNullability: Boolean) =
+    BoundReference(ordinal, baseReference.withNullability(newNullability))
+  override def withQualifiers(newQualifiers: Seq[String]) =
     BoundReference(ordinal, baseReference.withQualifiers(newQualifiers))
 
   override def toString = s"$baseReference:$ordinal"

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 66ae22e..934bad8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -57,6 +57,7 @@ abstract class NamedExpression extends Expression {
 abstract class Attribute extends NamedExpression {
   self: Product =>
 
+  def withNullability(newNullability: Boolean): Attribute
   def withQualifiers(newQualifiers: Seq[String]): Attribute
 
   def toAttribute = this
@@ -133,7 +134,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea
   /**
    * Returns a copy of this [[AttributeReference]] with changed nullability.
    */
-  def withNullability(newNullability: Boolean) = {
+  override def withNullability(newNullability: Boolean) = {
     if (nullable == newNullability) {
       this
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index bac5a72..0728fa7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.types._
 
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -46,10 +46,16 @@ case class Generate(
     child: LogicalPlan)
   extends UnaryNode {
 
-  protected def generatorOutput: Seq[Attribute] =
-    alias
+  protected def generatorOutput: Seq[Attribute] = {
+    val output = alias
       .map(a => generator.output.map(_.withQualifiers(a :: Nil)))
       .getOrElse(generator.output)
+    if (join && outer) {
+      output.map(_.withNullability(true))
+    } else {
+      output
+    }
+  }
 
   override def output =
     if (join) child.output ++ generatorOutput else generatorOutput
@@ -81,11 +87,20 @@ case class Join(
   condition: Option[Expression]) extends BinaryNode {
 
   override def references = condition.map(_.references).getOrElse(Set.empty)
-  override def output = joinType match {
-    case LeftSemi =>
-      left.output
-    case _ =>
-      left.output ++ right.output
+
+  override def output = {
+    joinType match {
+      case LeftSemi =>
+        left.output
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case _ =>
+        left.output ++ right.output
+    }
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index d85d2d7..c1ced8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -83,8 +83,8 @@ case class Aggregate(
       case a: AggregateExpression =>
         ComputedAggregate(
           a,
-          BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression],
-          AttributeReference(s"aggResult:$a", a.dataType, nullable = true)())
+          BindReferences.bindReference(a, childOutput),
+          AttributeReference(s"aggResult:$a", a.dataType, a.nullable)())
     }
   }.toArray
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index da1e08b..47b3d00 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.{Generator, JoinedRow, Literal, Projection}
+import org.apache.spark.sql.catalyst.expressions._
 
 /**
  * :: DeveloperApi ::
@@ -39,8 +39,16 @@ case class Generate(
     child: SparkPlan)
   extends UnaryNode {
 
+  protected def generatorOutput: Seq[Attribute] = {
+    if (join && outer) {
+      generator.output.map(_.withNullability(true))
+    } else {
+      generator.output
+    }
+  }
+
   override def output =
-    if (join) child.output ++ generator.output else generator.output
+    if (join) child.output ++ generatorOutput else generatorOutput
 
   override def execute() = {
     if (join) {

http://git-wip-us.apache.org/repos/asf/spark/blob/9d5ecf82/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 32c5f26..7d1f11c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -319,7 +319,18 @@ case class BroadcastNestedLoopJoin(
 
   override def otherCopyArgs = sqlContext :: Nil
 
-  def output = left.output ++ right.output
+  override def output = {
+    joinType match {
+      case LeftOuter =>
+        left.output ++ right.output.map(_.withNullability(true))
+      case RightOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output
+      case FullOuter =>
+        left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
+      case _ =>
+        left.output ++ right.output
+    }
+  }
 
   /** The Streamed Relation */
   def left = streamed