Posted to commits@spark.apache.org by rx...@apache.org on 2015/07/25 21:10:06 UTC

spark git commit: [SPARK-9192][SQL] add initialization phase for nondeterministic expression

Repository: spark
Updated Branches:
  refs/heads/master e2ec018e3 -> 2c94d0f24


[SPARK-9192][SQL] add initialization phase for nondeterministic expression

Currently, nondeterministic expressions are broken without an explicit initialization phase.

Let me take `MonotonicallyIncreasingID` as an example. This expression needs mutable state to remember how many times it has been evaluated, so we use `transient var count: Long` there. By being transient, `count` is reset to 0, and **only** to 0, when we serialize and deserialize the expression, because deserializing a transient variable yields its default value. There is *no way* to use a different initial value for `count` until we add an explicit initialization phase.
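
For illustration, here is a minimal standalone sketch of that initialization-phase pattern. `NondeterministicLike`, `IncreasingId`, and `startFrom` are hypothetical names used only for this example, not Spark classes; the real contract is the `Nondeterministic` trait in `Expression.scala` in the diff below, which evaluates `InternalRow`s.

```scala
// Minimal standalone sketch of the initialization-phase pattern (simplified,
// not Spark's actual classes).
trait NondeterministicLike {
  private var initialized = false

  // Callers must invoke this once on the executing side before the first eval().
  final def initialize(): Unit = {
    if (!initialized) {
      initInternal()
      initialized = true
    }
  }

  // Subclasses reset their mutable state here instead of relying on the
  // default value a transient field gets after deserialization.
  protected def initInternal(): Unit

  final def eval(): Any = {
    require(initialized, "nondeterministic expression should be initialized before evaluate")
    evalInternal()
  }

  protected def evalInternal(): Any
}

// Analogue of MonotonicallyIncreasingID: count starts from whatever
// initInternal() decides, not from the transient-field default.
class IncreasingId(startFrom: Long) extends NondeterministicLike {
  @transient private var count: Long = _

  override protected def initInternal(): Unit = { count = startFrom }

  override protected def evalInternal(): Any = { val c = count; count += 1; c }
}

object IncreasingIdDemo extends App {
  val id = new IncreasingId(startFrom = 100L)
  id.initialize()                         // the explicit initialization phase
  println((1 to 3).map(_ => id.eval()))   // Vector(100, 101, 102)
}
```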

Another use case is local execution of a `LocalRelation`: there is no serialize/deserialize phase at all, so without an explicit initialization phase we have no way to reset the mutable state.
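
In the patch this hook is driven by the interpreted evaluation paths: `InterpretedProjection`, `InterpretedMutableProjection` and `InterpretedPredicate.create` now walk each bound expression tree and initialize the nondeterministic expressions before the first evaluation (see the Projection.scala and predicates.scala hunks below), along the lines of:

```scala
expressions.foreach(_.foreach {
  case n: Nondeterministic => n.initialize()
  case _ =>
})
```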

Author: Wenchen Fan <cl...@outlook.com>

Closes #7535 from cloud-fan/init and squashes the following commits:

6c6f332 [Wenchen Fan] add test
ef68ff4 [Wenchen Fan] fix comments
9eac85e [Wenchen Fan] move init code to interpreted class
bb7d838 [Wenchen Fan] pulls out nondeterministic expressions into a project
b4a4fc7 [Wenchen Fan] revert a refactor
86fee36 [Wenchen Fan] add initialization phase for nondeterministic expression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2c94d0f2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2c94d0f2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2c94d0f2

Branch: refs/heads/master
Commit: 2c94d0f24a37fa079b56d534b0b0a4574209215b
Parents: e2ec018
Author: Wenchen Fan <cl...@outlook.com>
Authored: Sat Jul 25 12:10:02 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat Jul 25 12:10:02 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  35 ++++++-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  19 ++--
 .../sql/catalyst/expressions/Expression.scala   |  21 +++-
 .../sql/catalyst/expressions/Projection.scala   |  10 ++
 .../sql/catalyst/expressions/predicates.scala   |   4 +
 .../spark/sql/catalyst/expressions/random.scala |  12 ++-
 .../catalyst/plans/logical/basicOperators.scala |   3 +-
 .../sql/catalyst/analysis/AnalysisSuite.scala   |  96 ++++++++---------
 .../sql/catalyst/analysis/AnalysisTest.scala    | 105 +++++++++++++++++++
 .../expressions/ExpressionEvalHelper.scala      |   4 +
 .../expressions/MonotonicallyIncreasingID.scala |  13 ++-
 .../expressions/SparkPartitionID.scala          |   8 +-
 12 files changed, 254 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e916887..a723e92 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -19,10 +19,11 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression2, AggregateFunction2}
-import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf}
 import org.apache.spark.sql.types._
 import scala.collection.mutable.ArrayBuffer
 
@@ -78,7 +79,9 @@ class Analyzer(
       GlobalAggregates ::
       UnresolvedHavingClauseAttributes ::
       HiveTypeCoercion.typeCoercionRules ++
-      extendedResolutionRules : _*)
+      extendedResolutionRules : _*),
+    Batch("Nondeterministic", Once,
+      PullOutNondeterministic)
   )
 
   /**
@@ -910,6 +913,34 @@ class Analyzer(
         Project(finalProjectList, withWindow)
     }
   }
+
+  /**
+   * Pulls out nondeterministic expressions from LogicalPlan which is not Project or Filter,
+   * put them into an inner Project and finally project them away at the outer Project.
+   */
+  object PullOutNondeterministic extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      case p: Project => p
+      case f: Filter => f
+
+      // todo: It's hard to write a general rule to pull out nondeterministic expressions
+      // from LogicalPlan, currently we only do it for UnaryNode which has same output
+      // schema with its child.
+      case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
+        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
+          val ne = e match {
+            case n: NamedExpression => n
+            case _ => Alias(e, "_nondeterministic")()
+          }
+          new TreeNodeRef(e) -> ne
+        }.toMap
+        val newPlan = p.transformExpressions { case e =>
+          nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
+        }
+        val newChild = Project(p.child.output ++ nondeterministicExprs.values, p.child)
+        Project(p.output, newPlan.withNewChildren(newChild :: Nil))
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 81d473c..a373714 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
 
@@ -38,10 +37,10 @@ trait CheckAnalysis {
     throw new AnalysisException(msg)
   }
 
-  def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
+  protected def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
     exprs.flatMap(_.collect {
-      case e: Generator => true
-    }).nonEmpty
+      case e: Generator => e
+    }).length > 1
   }
 
   def checkAnalysis(plan: LogicalPlan): Unit = {
@@ -137,13 +136,21 @@ trait CheckAnalysis {
               s"""
                  |Failure when resolving conflicting references in Join:
                  |$plan
-                  |Conflicting attributes: ${conflictingAttributes.mkString(",")}
-                  |""".stripMargin)
+                 |Conflicting attributes: ${conflictingAttributes.mkString(",")}
+                 |""".stripMargin)
 
           case o if !o.resolved =>
             failAnalysis(
               s"unresolved operator ${operator.simpleString}")
 
+          case o if o.expressions.exists(!_.deterministic) &&
+            !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] =>
+            failAnalysis(
+              s"""nondeterministic expressions are only allowed in Project or Filter, found:
+                 | ${o.expressions.map(_.prettyString).mkString(",")}
+                 |in operator ${operator.simpleString}
+             """.stripMargin)
+
           case _ => // Analysis successful!
         }
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 3f72e6e..cb4c3f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -196,7 +196,26 @@ trait Unevaluable extends Expression {
  * An expression that is nondeterministic.
  */
 trait Nondeterministic extends Expression {
-  override def deterministic: Boolean = false
+  final override def deterministic: Boolean = false
+  final override def foldable: Boolean = false
+
+  private[this] var initialized = false
+
+  final def initialize(): Unit = {
+    if (!initialized) {
+      initInternal()
+      initialized = true
+    }
+  }
+
+  protected def initInternal(): Unit
+
+  final override def eval(input: InternalRow = null): Any = {
+    require(initialized, "nondeterministic expression should be initialized before evaluate")
+    evalInternal(input)
+  }
+
+  protected def evalInternal(input: InternalRow): Any
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index fb873e7..c1ed9cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -31,6 +31,11 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
+  expressions.foreach(_.foreach {
+    case n: Nondeterministic => n.initialize()
+    case _ =>
+  })
+
   // null check is required for when Kryo invokes the no-arg constructor.
   protected val exprArray = if (expressions != null) expressions.toArray else null
 
@@ -57,6 +62,11 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
   def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
+  expressions.foreach(_.foreach {
+    case n: Nondeterministic => n.initialize()
+    case _ =>
+  })
+
   private[this] val exprArray = expressions.toArray
   private[this] var mutableRow: MutableRow = new GenericMutableRow(exprArray.length)
   def currentValue: InternalRow = mutableRow

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3f1bd2a..5bfe1ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -30,6 +30,10 @@ object InterpretedPredicate {
     create(BindReferences.bindReference(expression, inputSchema))
 
   def create(expression: Expression): (InternalRow => Boolean) = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize()
+      case _ =>
+    }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
index aef24a5..8f30519 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/random.scala
@@ -38,9 +38,13 @@ abstract class RDG extends LeafExpression with Nondeterministic {
 
   /**
    * Record ID within each partition. By being transient, the Random Number Generator is
-   * reset every time we serialize and deserialize it.
+   * reset every time we serialize and deserialize and initialize it.
    */
-  @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  @transient protected var rng: XORShiftRandom = _
+
+  override protected def initInternal(): Unit = {
+    rng = new XORShiftRandom(seed + TaskContext.getPartitionId)
+  }
 
   override def nullable: Boolean = false
 
@@ -49,7 +53,7 @@ abstract class RDG extends LeafExpression with Nondeterministic {
 
 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
 case class Rand(seed: Long) extends RDG {
-  override def eval(input: InternalRow): Double = rng.nextDouble()
+  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
 
   def this() = this(Utils.random.nextLong())
 
@@ -72,7 +76,7 @@ case class Rand(seed: Long) extends RDG {
 
 /** Generate a random column with i.i.d. gaussian random distribution. */
 case class Randn(seed: Long) extends RDG {
-  override def eval(input: InternalRow): Double = rng.nextGaussian()
+  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
 
   def this() = this(Utils.random.nextLong())
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 57a1282..8e1a236 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression2
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
@@ -379,7 +378,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
   override def output: Seq[Attribute] = child.output
 
   override lazy val statistics: Statistics = {
-    val limit = limitExpr.eval(null).asInstanceOf[Int]
+    val limit = limitExpr.eval().asInstanceOf[Int]
     val sizeInBytes = (limit: Long) * output.map(a => a.dataType.defaultSize).sum
     Statistics(sizeInBytes = sizeInBytes)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 7e67427..ed645b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,10 +17,6 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
@@ -28,6 +24,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 
+// todo: remove this and use AnalysisTest instead.
 object AnalysisSuite {
   val caseSensitiveConf = new SimpleCatalystConf(true)
   val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -55,7 +52,7 @@ object AnalysisSuite {
     AttributeReference("a", StringType)(),
     AttributeReference("b", StringType)(),
     AttributeReference("c", DoubleType)(),
-    AttributeReference("d", DecimalType.SYSTEM_DEFAULT)(),
+    AttributeReference("d", DecimalType(10, 2))(),
     AttributeReference("e", ShortType)())
 
   val nestedRelation = LocalRelation(
@@ -81,8 +78,7 @@ object AnalysisSuite {
 }
 
 
-class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
-  import AnalysisSuite._
+class AnalysisSuite extends AnalysisTest {
 
   test("union project *") {
     val plan = (1 to 100)
@@ -91,7 +87,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
         a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
       }
 
-    assert(caseInsensitiveAnalyzer.execute(plan).resolved)
+    assertAnalysisSuccess(plan)
   }
 
   test("check project's resolved") {
@@ -106,61 +102,40 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
   }
 
   test("analyze project") {
-    assert(
-      caseSensitiveAnalyzer.execute(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
-        Project(testRelation.output, testRelation))
-
-    assert(
-      caseSensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("TbL.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
-
-    val e = intercept[AnalysisException] {
-      caseSensitiveAnalyze(
-        Project(Seq(UnresolvedAttribute("tBl.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
-    }
-    assert(e.getMessage().toLowerCase.contains("cannot resolve"))
-
-    assert(
-      caseInsensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("TbL.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
-
-    assert(
-      caseInsensitiveAnalyzer.execute(
-        Project(Seq(UnresolvedAttribute("tBl.a")),
-          UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
-        Project(testRelation.output, testRelation))
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("a")), testRelation),
+      Project(testRelation.output, testRelation))
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation))
+
+    assertAnalysisError(
+      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Seq("cannot resolve"))
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation),
+      caseSensitive = false)
+
+    checkAnalysis(
+      Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL"))),
+      Project(testRelation.output, testRelation),
+      caseSensitive = false)
   }
 
   test("resolve relations") {
-    val e = intercept[RuntimeException] {
-      caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
-    }
-    assert(e.getMessage == "Table Not Found: tAbLe")
+    assertAnalysisError(UnresolvedRelation(Seq("tAbLe"), None), Seq("Table Not Found: tAbLe"))
 
-    assert(
-      caseSensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation)
 
-    assert(
-      caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("tAbLe"), None), testRelation, caseSensitive = false)
 
-    assert(
-      caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
+    checkAnalysis(UnresolvedRelation(Seq("TaBlE"), None), testRelation, caseSensitive = false)
   }
 
-
   test("divide should be casted into fractional types") {
-    val testRelation2 = LocalRelation(
-      AttributeReference("a", StringType)(),
-      AttributeReference("b", StringType)(),
-      AttributeReference("c", DoubleType)(),
-      AttributeReference("d", DecimalType(10, 2))(),
-      AttributeReference("e", ShortType)())
-
     val plan = caseInsensitiveAnalyzer.execute(
       testRelation2.select(
         'a / Literal(2) as 'div1,
@@ -170,10 +145,21 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
         'e / 'e as 'div5))
     val pl = plan.asInstanceOf[Project].projectList
 
+    // StringType will be promoted into Double
     assert(pl(0).dataType == DoubleType)
     assert(pl(1).dataType == DoubleType)
     assert(pl(2).dataType == DoubleType)
-    assert(pl(3).dataType == DoubleType)  // StringType will be promoted into Double
+    assert(pl(3).dataType == DoubleType)
     assert(pl(4).dataType == DoubleType)
   }
+
+  test("pull out nondeterministic expressions from unary LogicalPlan") {
+    val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
+    val projected = Alias(Rand(33), "_nondeterministic")()
+    val expected =
+      Project(testRelation.output,
+        RepartitionByExpression(Seq(projected.toAttribute),
+          Project(testRelation.output :+ projected, testRelation)))
+    checkAnalysis(plan, expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
new file mode 100644
index 0000000..fdb4f28
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.types._
+
+trait AnalysisTest extends PlanTest {
+  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
+
+  val testRelation2 = LocalRelation(
+    AttributeReference("a", StringType)(),
+    AttributeReference("b", StringType)(),
+    AttributeReference("c", DoubleType)(),
+    AttributeReference("d", DecimalType(10, 2))(),
+    AttributeReference("e", ShortType)())
+
+  val nestedRelation = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("duplicateField", StringType) ::
+        StructField("duplicateField", StringType) ::
+        StructField("differentCase", StringType) ::
+        StructField("differentcase", StringType) :: Nil
+    ))())
+
+  val nestedRelation2 = LocalRelation(
+    AttributeReference("top", StructType(
+      StructField("aField", StringType) ::
+        StructField("bField", StringType) ::
+        StructField("cField", StringType) :: Nil
+    ))())
+
+  val listRelation = LocalRelation(
+    AttributeReference("list", ArrayType(IntegerType))())
+
+  val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
+    val caseSensitiveConf = new SimpleCatalystConf(true)
+    val caseInsensitiveConf = new SimpleCatalystConf(false)
+
+    val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
+    val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)
+
+    caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+
+    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
+      override val extendedResolutionRules = EliminateSubQueries :: Nil
+    } ->
+    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
+      override val extendedResolutionRules = EliminateSubQueries :: Nil
+    }
+  }
+
+  protected def getAnalyzer(caseSensitive: Boolean) = {
+    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
+  }
+
+  protected def checkAnalysis(
+      inputPlan: LogicalPlan,
+      expectedPlan: LogicalPlan,
+      caseSensitive: Boolean = true): Unit = {
+    val analyzer = getAnalyzer(caseSensitive)
+    val actualPlan = analyzer.execute(inputPlan)
+    analyzer.checkAnalysis(actualPlan)
+    comparePlans(actualPlan, expectedPlan)
+  }
+
+  protected def assertAnalysisSuccess(
+      inputPlan: LogicalPlan,
+      caseSensitive: Boolean = true): Unit = {
+    val analyzer = getAnalyzer(caseSensitive)
+    analyzer.checkAnalysis(analyzer.execute(inputPlan))
+  }
+
+  protected def assertAnalysisError(
+      inputPlan: LogicalPlan,
+      expectedErrors: Seq[String],
+      caseSensitive: Boolean = true): Unit = {
+    val analyzer = getAnalyzer(caseSensitive)
+    // todo: make sure we throw AnalysisException during analysis
+    val e = intercept[Exception] {
+      analyzer.checkAnalysis(analyzer.execute(inputPlan))
+    }
+    expectedErrors.forall(e.getMessage.contains)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 4930219..852a8b2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -64,6 +64,10 @@ trait ExpressionEvalHelper {
   }
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
+    expression.foreach {
+      case n: Nondeterministic => n.initialize()
+      case _ =>
+    }
     expression.eval(inputRow)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
index 2645eb1..eca36b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala
@@ -37,17 +37,22 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with
 
   /**
    * Record ID within each partition. By being transient, count's value is reset to 0 every time
-   * we serialize and deserialize it.
+   * we serialize and deserialize and initialize it.
    */
-  @transient private[this] var count: Long = 0L
+  @transient private[this] var count: Long = _
 
-  @transient private lazy val partitionMask = TaskContext.getPartitionId().toLong << 33
+  @transient private[this] var partitionMask: Long = _
+
+  override protected def initInternal(): Unit = {
+    count = 0L
+    partitionMask = TaskContext.getPartitionId().toLong << 33
+  }
 
   override def nullable: Boolean = false
 
   override def dataType: DataType = LongType
 
-  override def eval(input: InternalRow): Long = {
+  override protected def evalInternal(input: InternalRow): Long = {
     val currentCount = count
     count += 1
     partitionMask + currentCount

http://git-wip-us.apache.org/repos/asf/spark/blob/2c94d0f2/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
index 53ddd47..61ef079 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/SparkPartitionID.scala
@@ -33,9 +33,13 @@ private[sql] case object SparkPartitionID extends LeafExpression with Nondetermi
 
   override def dataType: DataType = IntegerType
 
-  @transient private lazy val partitionId = TaskContext.getPartitionId()
+  @transient private[this] var partitionId: Int = _
 
-  override def eval(input: InternalRow): Int = partitionId
+  override protected def initInternal(): Unit = {
+    partitionId = TaskContext.getPartitionId()
+  }
+
+  override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     val idTerm = ctx.freshName("partitionId")

