You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/11/10 03:11:35 UTC

spark git commit: [SPARK-22308][TEST-MAVEN] Support alternative unit testing styles in external applications

Repository: spark
Updated Branches:
  refs/heads/master f5fe63f7b -> b57ed2245


[SPARK-22308][TEST-MAVEN] Support alternative unit testing styles in external applications

Continuation of PR #19529 (https://github.com/apache/spark/pull/19529#issuecomment-340252119)

The problem with the Maven build in the previous PR was the new tests: creating a Spark session outside the tests meant there was more than one Spark session around at a time.
I was using the Spark session outside the tests so that the tests could share data; I've changed it so that each test creates the data anew.

Author: Nathan Kronenfeld <ni...@gmail.com>
Author: Nathan Kronenfeld <nk...@uncharted.software>

Closes #19705 from nkronenfeld/alternative-style-tests-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b57ed224
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b57ed224
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b57ed224

Branch: refs/heads/master
Commit: b57ed2245c705fb0964462cf4492b809ade836c6
Parents: f5fe63f
Author: Nathan Kronenfeld <ni...@gmail.com>
Authored: Thu Nov 9 19:11:30 2017 -0800
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Nov 9 19:11:30 2017 -0800

----------------------------------------------------------------------
 .../org/apache/spark/SharedSparkContext.scala   |  17 +-
 .../spark/sql/catalyst/plans/PlanTest.scala     |  10 +-
 .../spark/sql/test/GenericFlatSpecSuite.scala   |  47 +++++
 .../spark/sql/test/GenericFunSpecSuite.scala    |  49 ++++++
 .../spark/sql/test/GenericWordSpecSuite.scala   |  53 ++++++
 .../apache/spark/sql/test/SQLTestUtils.scala    | 173 ++++++++++---------
 .../spark/sql/test/SharedSQLContext.scala       |  84 +--------
 .../spark/sql/test/SharedSparkSession.scala     | 119 +++++++++++++
 8 files changed, 387 insertions(+), 165 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 6aedcb1..1aa1c42 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel
 
   var conf = new SparkConf(false)
 
+  /**
+   * Initialize the [[SparkContext]], if it has not been created yet.  Generally, this is just
+   * called from beforeAll; however, in tests using styles other than FunSuite, there is often
+   * code between test group constructs and the actual tests that relies on the context.  It is
+   * purely a semantic difference, but it makes more sense to call 'initializeContext' between
+   * a 'describe' and an 'it' call than it does to call 'beforeAll'.
+   */
+  protected def initializeContext(): Unit = {
+    if (null == _sc) {
+      _sc = new SparkContext(
+        "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+    }
+  }
+
   override def beforeAll() {
     super.beforeAll()
-    _sc = new SparkContext(
-      "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+    initializeContext()
   }
 
   override def afterAll() {

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 10bdfaf..82c5307 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.scalatest.Suite
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
@@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Provides helper methods for comparing plans.
  */
-trait PlanTest extends SparkFunSuite with PredicateHelper {
+trait PlanTest extends SparkFunSuite with PlanTestBase
+
+/**
+ * Provides helper methods for comparing plans, without mandating a FunSuite: any
+ * ScalaTest [[Suite]] can mix this in.
+ */
+trait PlanTestBase extends PredicateHelper { self: Suite =>
 
   // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules
   protected def conf = SQLConf.get

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
new file mode 100644
index 0000000..14ac479
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.FlatSpec
+
+import org.apache.spark.sql.Dataset
+
+/**
+ * The purpose of this suite is to make sure that generic FlatSpec-based Scala
+ * tests work with a shared Spark session.
+ */
+class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession {
+  import testImplicits._
+
+  private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  "A Simple Dataset" should "have the specified number of elements" in {
+    assert(8 === ds.count)
+  }
+  it should "have the specified number of unique elements" in {
+    assert(8 === ds.distinct.count)
+  }
+  it should "have the specified number of elements in each column" in {
+    assert(8 === ds.select("_1").count)
+    assert(8 === ds.select("_2").count)
+  }
+  it should "have the correct number of distinct elements in each column" in {
+    assert(8 === ds.select("_1").distinct.count)
+    assert(4 === ds.select("_2").distinct.count)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
new file mode 100644
index 0000000..e8971e3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.FunSpec
+
+import org.apache.spark.sql.Dataset
+
+/**
+ * The purpose of this suite is to make sure that generic FunSpec-based Scala
+ * tests work with a shared Spark session.
+ */
+class GenericFunSpecSuite extends FunSpec with SharedSparkSession {
+  import testImplicits._
+
+  private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  describe("Simple Dataset") {
+    it("should have the specified number of elements") {
+      assert(8 === ds.count)
+    }
+    it("should have the specified number of unique elements") {
+      assert(8 === ds.distinct.count)
+    }
+    it("should have the specified number of elements in each column") {
+      assert(8 === ds.select("_1").count)
+      assert(8 === ds.select("_2").count)
+    }
+    it("should have the correct number of distinct elements in each column") {
+      assert(8 === ds.select("_1").distinct.count)
+      assert(4 === ds.select("_2").distinct.count)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
new file mode 100644
index 0000000..44655a5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.WordSpec
+
+import org.apache.spark.sql.Dataset
+
+/**
+ * The purpose of this suite is to make sure that generic WordSpec-based Scala
+ * tests work with a shared Spark session.
+ */
+class GenericWordSpecSuite extends WordSpec with SharedSparkSession {
+  import testImplicits._
+
+  private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  "A Simple Dataset" when {
+    "looked at as complete rows" should {
+      "have the specified number of elements" in {
+        assert(8 === ds.count)
+      }
+      "have the specified number of unique elements" in {
+        assert(8 === ds.distinct.count)
+      }
+    }
+    "refined to specific columns" should {
+      "have the specified number of elements in each column" in {
+        assert(8 === ds.select("_1").count)
+        assert(8 === ds.select("_2").count)
+      }
+      "have the correct number of distinct elements in each column" in {
+        assert(8 === ds.select("_1").distinct.count)
+        assert(4 === ds.select("_2").distinct.count)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index a14a144..b4248b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,7 +27,7 @@ import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkFunSuite
@@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.{UninterruptibleThread, Utils}
+import org.apache.spark.util.UninterruptibleThread
+import org.apache.spark.util.Utils
 
 /**
- * Helper trait that should be extended by all SQL test suites.
+ * Helper trait that should be extended by all SQL test suites within the Spark
+ * code base.
  *
  * This allows subclasses to plugin a custom `SQLContext`. It comes with test data
  * prepared in advance as well as all implicit conversions used extensively by dataframes.
@@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
  * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is
  * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
  */
-private[sql] trait SQLTestUtils
-  extends SparkFunSuite with Eventually
+private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest {
+  // Whether to materialize all test data before the first test is run
+  private var loadTestDataBeforeTests = false
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    if (loadTestDataBeforeTests) {
+      loadTestData()
+    }
+  }
+
+  /**
+   * Materialize the test data immediately after the `SQLContext` is set up.
+   * This is necessary if the data is accessed by name but not through direct reference.
+   */
+  protected def setupTestData(): Unit = {
+    loadTestDataBeforeTests = true
+  }
+
+  /**
+   * Disable stdout and stderr when running the test. To not output the logs to the console,
+   * ConsoleAppender's `follow` should be set to `true` so that it honors reassignments of
+   * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if
+   * we change System.out and System.err.
+   */
+  protected def testQuietly(name: String)(f: => Unit): Unit = {
+    test(name) {
+      quietly {
+        f
+      }
+    }
+  }
+
+  /**
+   * Run a test on a separate `UninterruptibleThread`.
+   */
+  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
+    (body: => Unit): Unit = {
+    val timeoutMillis = 10000
+    @transient var ex: Throwable = null
+
+    def runOnThread(): Unit = {
+      val thread = new UninterruptibleThread(s"Testing thread for test $name") {
+        override def run(): Unit = {
+          try {
+            body
+          } catch {
+            case NonFatal(e) =>
+              ex = e
+          }
+        }
+      }
+      thread.setDaemon(true)
+      thread.start()
+      thread.join(timeoutMillis)
+      if (thread.isAlive) {
+        thread.interrupt()
+        // If this interrupt does not work, then this thread is most likely running something that
+        // is not interruptible. There is not much point in waiting for the thread to terminate;
+        // we would rather let the JVM terminate the thread on exit.
+        fail(
+          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
+            s" $timeoutMillis ms")
+      } else if (ex != null) {
+        throw ex
+      }
+    }
+
+    if (quietly) {
+      testQuietly(name) { runOnThread() }
+    } else {
+      test(name) { runOnThread() }
+    }
+  }
+}
+
+/**
+ * Helper trait with SQL test utilities that can be mixed into a suite of any ScalaTest style.
+ *
+ * This allows subclasses to plug in a custom `SQLContext`.
+ * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`.
+ *
+ * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is
+ * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
+ */
+private[sql] trait SQLTestUtilsBase
+  extends Eventually
   with BeforeAndAfterAll
   with SQLTestData
-  with PlanTest { self =>
+  with PlanTestBase { self: Suite =>
 
   protected def sparkContext = spark.sparkContext
 
-  // Whether to materialize all test data before the first test is run
-  private var loadTestDataBeforeTests = false
-
   // Shorthand for running a query using our SQLContext
   protected lazy val sql = spark.sql _
 
@@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils
     protected override def _sqlContext: SQLContext = self.spark.sqlContext
   }
 
-  /**
-   * Materialize the test data immediately after the `SQLContext` is set up.
-   * This is necessary if the data is accessed by name but not through direct reference.
-   */
-  protected def setupTestData(): Unit = {
-    loadTestDataBeforeTests = true
-  }
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    if (loadTestDataBeforeTests) {
-      loadTestData()
-    }
-  }
-
   protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
     SparkSession.setActiveSession(spark)
     super.withSQLConf(pairs: _*)(f)
@@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils
     Dataset.ofRows(spark, plan)
   }
 
-  /**
-   * Disable stdout and stderr when running the test. To not output the logs to the console,
-   * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of
-   * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if
-   * we change System.out and System.err.
-   */
-  protected def testQuietly(name: String)(f: => Unit): Unit = {
-    test(name) {
-      quietly {
-        f
-      }
-    }
-  }
-
-  /**
-   * Run a test on a separate `UninterruptibleThread`.
-   */
-  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
-    (body: => Unit): Unit = {
-    val timeoutMillis = 10000
-    @transient var ex: Throwable = null
-
-    def runOnThread(): Unit = {
-      val thread = new UninterruptibleThread(s"Testing thread for test $name") {
-        override def run(): Unit = {
-          try {
-            body
-          } catch {
-            case NonFatal(e) =>
-              ex = e
-          }
-        }
-      }
-      thread.setDaemon(true)
-      thread.start()
-      thread.join(timeoutMillis)
-      if (thread.isAlive) {
-        thread.interrupt()
-        // If this interrupt does not work, then this thread is most likely running something that
-        // is not interruptible. There is not much point to wait for the thread to termniate, and
-        // we rather let the JVM terminate the thread on exit.
-        fail(
-          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
-            s" $timeoutMillis ms")
-      } else if (ex != null) {
-        throw ex
-      }
-    }
-
-    if (quietly) {
-      testQuietly(name) { runOnThread() }
-    } else {
-      test(name) { runOnThread() }
-    }
-  }
 
   /**
    * This method is used to make the given path qualified, when a path

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index cd8d070..4d578e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -17,86 +17,4 @@
 
 package org.apache.spark.sql.test
 
-import scala.concurrent.duration._
-
-import org.scalatest.BeforeAndAfterEach
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.{DebugFilesystem, SparkConf}
-import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.internal.SQLConf
-
-/**
- * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
- */
-trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually {
-
-  protected def sparkConf = {
-    new SparkConf()
-      .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
-      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
-      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
-  }
-
-  /**
-   * The [[TestSparkSession]] to use for all tests in this suite.
-   *
-   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
-   * mode with the default test configurations.
-   */
-  private var _spark: TestSparkSession = null
-
-  /**
-   * The [[TestSparkSession]] to use for all tests in this suite.
-   */
-  protected implicit def spark: SparkSession = _spark
-
-  /**
-   * The [[TestSQLContext]] to use for all tests in this suite.
-   */
-  protected implicit def sqlContext: SQLContext = _spark.sqlContext
-
-  protected def createSparkSession: TestSparkSession = {
-    new TestSparkSession(sparkConf)
-  }
-
-  /**
-   * Initialize the [[TestSparkSession]].
-   */
-  protected override def beforeAll(): Unit = {
-    SparkSession.sqlListener.set(null)
-    if (_spark == null) {
-      _spark = createSparkSession
-    }
-    // Ensure we have initialized the context before calling parent code
-    super.beforeAll()
-  }
-
-  /**
-   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
-   */
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    if (_spark != null) {
-      _spark.sessionState.catalog.reset()
-      _spark.stop()
-      _spark = null
-    }
-  }
-
-  protected override def beforeEach(): Unit = {
-    super.beforeEach()
-    DebugFilesystem.clearOpenStreams()
-  }
-
-  protected override def afterEach(): Unit = {
-    super.afterEach()
-    // Clear all persistent datasets after each test
-    spark.sharedState.cacheManager.clearCache()
-    // files can be closed from other threads, so wait a bit
-    // normally this doesn't take more than 1s
-    eventually(timeout(10.seconds)) {
-      DebugFilesystem.assertNoOpenStreams()
-    }
-  }
-}
+trait SharedSQLContext extends SQLTestUtils with SharedSparkSession

http://git-wip-us.apache.org/repos/asf/spark/blob/b57ed224/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
new file mode 100644
index 0000000..e0568a3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import scala.concurrent.duration._
+
+import org.scalatest.{BeforeAndAfterEach, Suite}
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{DebugFilesystem, SparkConf}
+import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
+ */
+trait SharedSparkSession
+  extends SQLTestUtilsBase
+  with BeforeAndAfterEach
+  with Eventually { self: Suite =>
+
+  protected def sparkConf = {
+    new SparkConf()
+      .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
+      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+  }
+
+  /**
+   * The [[TestSparkSession]] to use for all tests in this suite.
+   *
+   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
+   * mode with the default test configurations.
+   */
+  private var _spark: TestSparkSession = null
+
+  /**
+   * The [[TestSparkSession]] to use for all tests in this suite.
+   */
+  protected implicit def spark: SparkSession = _spark
+
+  /**
+   * The [[TestSQLContext]] to use for all tests in this suite.
+   */
+  protected implicit def sqlContext: SQLContext = _spark.sqlContext
+
+  protected def createSparkSession: TestSparkSession = {
+    new TestSparkSession(sparkConf)
+  }
+
+  /**
+   * Initialize the [[TestSparkSession]], if it has not been created yet.  Generally, this
+   * is just called from beforeAll; however, in tests using styles other than FunSuite,
+   * there is often code between test group constructs and the actual tests that needs
+   * this session.  It is purely a semantic difference, but semantically, it makes more
+   * sense to call 'initializeSession' between a 'describe' and an 'it' call than it does
+   * to call 'beforeAll'.  Once a session exists, further calls do nothing; in particular,
+   * `SparkSession.sqlListener` is only cleared right before a new session is created.
+   */
+  protected def initializeSession(): Unit = {
+    if (_spark == null) {
+      SparkSession.sqlListener.set(null)
+      _spark = createSparkSession
+    }
+  }
+
+  /**
+   * Make sure the [[TestSparkSession]] is initialized before any tests are run.
+   */
+  protected override def beforeAll(): Unit = {
+    initializeSession()
+
+    // Ensure we have initialized the context before calling parent code
+    super.beforeAll()
+  }
+
+  /**
+   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
+   */
+  protected override def afterAll(): Unit = {
+    super.afterAll()
+    if (_spark != null) {
+      _spark.sessionState.catalog.reset()
+      _spark.stop()
+      _spark = null
+    }
+  }
+
+  protected override def beforeEach(): Unit = {
+    super.beforeEach()
+    DebugFilesystem.clearOpenStreams()
+  }
+
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    // Clear all persistent datasets after each test
+    spark.sharedState.cacheManager.clearCache()
+    // files can be closed from other threads, so wait a bit
+    // normally this doesn't take more than 1s
+    eventually(timeout(10.seconds)) {
+      DebugFilesystem.assertNoOpenStreams()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org