You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by li...@apache.org on 2017/10/26 07:29:52 UTC

spark git commit: [SPARK-22308] Support alternative unit testing styles in external applications

Repository: spark
Updated Branches:
  refs/heads/master 5433be44c -> 592cfeab9


[SPARK-22308] Support alternative unit testing styles in external applications

## What changes were proposed in this pull request?
Support unit tests of external code (i.e., applications that use Spark) that are written with scalatest but do not want to use FunSuite.  SharedSparkContext already supports this, but SharedSQLContext does not.

I've introduced SharedSparkSession as a parent of SharedSQLContext, written so that it supports all scalatest styles.

## How was this patch tested?
There are three new unit test suites added that just test using FunSpec, FlatSpec, and WordSpec.

Author: Nathan Kronenfeld <ni...@gmail.com>

Closes #19529 from nkronenfeld/alternative-style-tests-2.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/592cfeab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/592cfeab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/592cfeab

Branch: refs/heads/master
Commit: 592cfeab9caeff955d115a1ca5014ede7d402907
Parents: 5433be4
Author: Nathan Kronenfeld <ni...@gmail.com>
Authored: Thu Oct 26 00:29:49 2017 -0700
Committer: gatorsmile <ga...@gmail.com>
Committed: Thu Oct 26 00:29:49 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/SharedSparkContext.scala   |  17 +-
 .../spark/sql/catalyst/plans/PlanTest.scala     |  10 +-
 .../spark/sql/test/GenericFlatSpecSuite.scala   |  45 +++++
 .../spark/sql/test/GenericFunSpecSuite.scala    |  47 +++++
 .../spark/sql/test/GenericWordSpecSuite.scala   |  51 ++++++
 .../apache/spark/sql/test/SQLTestUtils.scala    | 173 ++++++++++---------
 .../spark/sql/test/SharedSQLContext.scala       |  84 +--------
 .../spark/sql/test/SharedSparkSession.scala     | 119 +++++++++++++
 8 files changed, 381 insertions(+), 165 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
index 6aedcb1..1aa1c42 100644
--- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala
@@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel
 
   var conf = new SparkConf(false)
 
+  /**
+   * Initialize the [[SparkContext]].  Generally, this is just called from beforeAll; however, in
+   * tests using styles other than FunSuite, there is often code between the test group constructs
+   * and the actual tests that needs the context.  It is purely a semantic difference, but
+   * semantically, it makes more sense to call 'initializeContext' between a 'describe' and an
+   * 'it' call than it does to call 'beforeAll'.
+   */
+  protected def initializeContext(): Unit = {
+    if (null == _sc) {
+      _sc = new SparkContext(
+        "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+    }
+  }
+
   override def beforeAll() {
     super.beforeAll()
-    _sc = new SparkContext(
-      "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName))
+    initializeContext()
   }
 
   override def afterAll() {

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 10bdfaf..82c5307 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import org.scalatest.Suite
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
@@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf
 /**
  * Provides helper methods for comparing plans.
  */
-trait PlanTest extends SparkFunSuite with PredicateHelper {
+trait PlanTest extends SparkFunSuite with PlanTestBase
+
+/**
+ * Provides helper methods for comparing plans, but without the overhead of
+ * mandating a FunSuite.
+ */
+trait PlanTestBase extends PredicateHelper { self: Suite =>
 
   // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules
   protected def conf = SQLConf.get

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
new file mode 100644
index 0000000..6179585
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.FlatSpec
+
+/**
+ * The purpose of this suite is to make sure that generic FlatSpec-based scala
+ * tests work with a shared spark session
+ */
+class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession {
+  import testImplicits._
+  initializeSession()
+  val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  "A Simple Dataset" should "have the specified number of elements" in {
+    assert(8 === ds.count)
+  }
+  it should "have the specified number of unique elements" in {
+      assert(8 === ds.distinct.count)
+  }
+  it should "have the specified number of elements in each column" in {
+    assert(8 === ds.select("_1").count)
+    assert(8 === ds.select("_2").count)
+  }
+  it should "have the correct number of distinct elements in each column" in {
+    assert(8 === ds.select("_1").distinct.count)
+    assert(4 === ds.select("_2").distinct.count)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
new file mode 100644
index 0000000..15139ee
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.FunSpec
+
+/**
+ * The purpose of this suite is to make sure that generic FunSpec-based scala
+ * tests work with a shared spark session
+ */
+class GenericFunSpecSuite extends FunSpec with SharedSparkSession {
+  import testImplicits._
+  initializeSession()
+  val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  describe("Simple Dataset") {
+    it("should have the specified number of elements") {
+      assert(8 === ds.count)
+    }
+    it("should have the specified number of unique elements") {
+      assert(8 === ds.distinct.count)
+    }
+    it("should have the specified number of elements in each column") {
+      assert(8 === ds.select("_1").count)
+      assert(8 === ds.select("_2").count)
+    }
+    it("should have the correct number of distinct elements in each column") {
+      assert(8 === ds.select("_1").distinct.count)
+      assert(4 === ds.select("_2").distinct.count)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
new file mode 100644
index 0000000..b6548bf
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.scalatest.WordSpec
+
+/**
+ * The purpose of this suite is to make sure that generic WordSpec-based scala
+ * tests work with a shared spark session
+ */
+class GenericWordSpecSuite extends WordSpec with SharedSparkSession {
+  import testImplicits._
+  initializeSession()
+  val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS
+
+  "A Simple Dataset" when {
+    "looked at as complete rows" should {
+      "have the specified number of elements" in {
+        assert(8 === ds.count)
+      }
+      "have the specified number of unique elements" in {
+        assert(8 === ds.distinct.count)
+      }
+    }
+    "refined to specific columns" should {
+      "have the specified number of elements in each column" in {
+        assert(8 === ds.select("_1").count)
+        assert(8 === ds.select("_2").count)
+      }
+      "have the correct number of distinct elements in each column" in {
+        assert(8 === ds.select("_1").distinct.count)
+        assert(4 === ds.select("_2").distinct.count)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index a14a144..b4248b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,7 +27,7 @@ import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, Suite}
 import org.scalatest.concurrent.Eventually
 
 import org.apache.spark.SparkFunSuite
@@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.{UninterruptibleThread, Utils}
+import org.apache.spark.util.UninterruptibleThread
+import org.apache.spark.util.Utils
 
 /**
- * Helper trait that should be extended by all SQL test suites.
+ * Helper trait that should be extended by all SQL test suites within the Spark
+ * code base.
  *
  * This allows subclasses to plugin a custom `SQLContext`. It comes with test data
  * prepared in advance as well as all implicit conversions used extensively by dataframes.
@@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils}
  * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is
  * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
  */
-private[sql] trait SQLTestUtils
-  extends SparkFunSuite with Eventually
+private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest {
+  // Whether to materialize all test data before the first test is run
+  private var loadTestDataBeforeTests = false
+
+  protected override def beforeAll(): Unit = {
+    super.beforeAll()
+    if (loadTestDataBeforeTests) {
+      loadTestData()
+    }
+  }
+
+  /**
+   * Materialize the test data immediately after the `SQLContext` is set up.
+   * This is necessary if the data is accessed by name but not through direct reference.
+   */
+  protected def setupTestData(): Unit = {
+    loadTestDataBeforeTests = true
+  }
+
+  /**
+   * Disable stdout and stderr when running the test. To not output the logs to the console,
+   * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of
+   * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if
+   * we change System.out and System.err.
+   */
+  protected def testQuietly(name: String)(f: => Unit): Unit = {
+    test(name) {
+      quietly {
+        f
+      }
+    }
+  }
+
+  /**
+   * Run a test on a separate `UninterruptibleThread`.
+   */
+  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
+    (body: => Unit): Unit = {
+    val timeoutMillis = 10000
+    @transient var ex: Throwable = null
+
+    def runOnThread(): Unit = {
+      val thread = new UninterruptibleThread(s"Testing thread for test $name") {
+        override def run(): Unit = {
+          try {
+            body
+          } catch {
+            case NonFatal(e) =>
+              ex = e
+          }
+        }
+      }
+      thread.setDaemon(true)
+      thread.start()
+      thread.join(timeoutMillis)
+      if (thread.isAlive) {
+        thread.interrupt()
+        // If this interrupt does not work, then this thread is most likely running something that
+        // is not interruptible. There is not much point to wait for the thread to terminate, and
+        // we would rather let the JVM terminate the thread on exit.
+        fail(
+          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
+            s" $timeoutMillis ms")
+      } else if (ex != null) {
+        throw ex
+      }
+    }
+
+    if (quietly) {
+      testQuietly(name) { runOnThread() }
+    } else {
+      test(name) { runOnThread() }
+    }
+  }
+}
+
+/**
+ * Helper trait that can be extended by all external SQL test suites.
+ *
+ * This allows subclasses to plugin a custom `SQLContext`.
+ * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`.
+ *
+ * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is
+ * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
+ */
+private[sql] trait SQLTestUtilsBase
+  extends Eventually
   with BeforeAndAfterAll
   with SQLTestData
-  with PlanTest { self =>
+  with PlanTestBase { self: Suite =>
 
   protected def sparkContext = spark.sparkContext
 
-  // Whether to materialize all test data before the first test is run
-  private var loadTestDataBeforeTests = false
-
   // Shorthand for running a query using our SQLContext
   protected lazy val sql = spark.sql _
 
@@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils
     protected override def _sqlContext: SQLContext = self.spark.sqlContext
   }
 
-  /**
-   * Materialize the test data immediately after the `SQLContext` is set up.
-   * This is necessary if the data is accessed by name but not through direct reference.
-   */
-  protected def setupTestData(): Unit = {
-    loadTestDataBeforeTests = true
-  }
-
-  protected override def beforeAll(): Unit = {
-    super.beforeAll()
-    if (loadTestDataBeforeTests) {
-      loadTestData()
-    }
-  }
-
   protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
     SparkSession.setActiveSession(spark)
     super.withSQLConf(pairs: _*)(f)
@@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils
     Dataset.ofRows(spark, plan)
   }
 
-  /**
-   * Disable stdout and stderr when running the test. To not output the logs to the console,
-   * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of
-   * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if
-   * we change System.out and System.err.
-   */
-  protected def testQuietly(name: String)(f: => Unit): Unit = {
-    test(name) {
-      quietly {
-        f
-      }
-    }
-  }
-
-  /**
-   * Run a test on a separate `UninterruptibleThread`.
-   */
-  protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false)
-    (body: => Unit): Unit = {
-    val timeoutMillis = 10000
-    @transient var ex: Throwable = null
-
-    def runOnThread(): Unit = {
-      val thread = new UninterruptibleThread(s"Testing thread for test $name") {
-        override def run(): Unit = {
-          try {
-            body
-          } catch {
-            case NonFatal(e) =>
-              ex = e
-          }
-        }
-      }
-      thread.setDaemon(true)
-      thread.start()
-      thread.join(timeoutMillis)
-      if (thread.isAlive) {
-        thread.interrupt()
-        // If this interrupt does not work, then this thread is most likely running something that
-        // is not interruptible. There is not much point to wait for the thread to termniate, and
-        // we rather let the JVM terminate the thread on exit.
-        fail(
-          s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" +
-            s" $timeoutMillis ms")
-      } else if (ex != null) {
-        throw ex
-      }
-    }
-
-    if (quietly) {
-      testQuietly(name) { runOnThread() }
-    } else {
-      test(name) { runOnThread() }
-    }
-  }
 
   /**
    * This method is used to make the given path qualified, when a path

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
index cd8d070..4d578e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -17,86 +17,4 @@
 
 package org.apache.spark.sql.test
 
-import scala.concurrent.duration._
-
-import org.scalatest.BeforeAndAfterEach
-import org.scalatest.concurrent.Eventually
-
-import org.apache.spark.{DebugFilesystem, SparkConf}
-import org.apache.spark.sql.{SparkSession, SQLContext}
-import org.apache.spark.sql.internal.SQLConf
-
-/**
- * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
- */
-trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually {
-
-  protected def sparkConf = {
-    new SparkConf()
-      .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
-      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
-      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
-  }
-
-  /**
-   * The [[TestSparkSession]] to use for all tests in this suite.
-   *
-   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
-   * mode with the default test configurations.
-   */
-  private var _spark: TestSparkSession = null
-
-  /**
-   * The [[TestSparkSession]] to use for all tests in this suite.
-   */
-  protected implicit def spark: SparkSession = _spark
-
-  /**
-   * The [[TestSQLContext]] to use for all tests in this suite.
-   */
-  protected implicit def sqlContext: SQLContext = _spark.sqlContext
-
-  protected def createSparkSession: TestSparkSession = {
-    new TestSparkSession(sparkConf)
-  }
-
-  /**
-   * Initialize the [[TestSparkSession]].
-   */
-  protected override def beforeAll(): Unit = {
-    SparkSession.sqlListener.set(null)
-    if (_spark == null) {
-      _spark = createSparkSession
-    }
-    // Ensure we have initialized the context before calling parent code
-    super.beforeAll()
-  }
-
-  /**
-   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
-   */
-  protected override def afterAll(): Unit = {
-    super.afterAll()
-    if (_spark != null) {
-      _spark.sessionState.catalog.reset()
-      _spark.stop()
-      _spark = null
-    }
-  }
-
-  protected override def beforeEach(): Unit = {
-    super.beforeEach()
-    DebugFilesystem.clearOpenStreams()
-  }
-
-  protected override def afterEach(): Unit = {
-    super.afterEach()
-    // Clear all persistent datasets after each test
-    spark.sharedState.cacheManager.clearCache()
-    // files can be closed from other threads, so wait a bit
-    // normally this doesn't take more than 1s
-    eventually(timeout(10.seconds)) {
-      DebugFilesystem.assertNoOpenStreams()
-    }
-  }
-}
+trait SharedSQLContext extends SQLTestUtils with SharedSparkSession

http://git-wip-us.apache.org/repos/asf/spark/blob/592cfeab/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
new file mode 100644
index 0000000..e0568a3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import scala.concurrent.duration._
+
+import org.scalatest.{BeforeAndAfterEach, Suite}
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{DebugFilesystem, SparkConf}
+import org.apache.spark.sql.{SparkSession, SQLContext}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]].
+ */
+trait SharedSparkSession
+  extends SQLTestUtilsBase
+  with BeforeAndAfterEach
+  with Eventually { self: Suite =>
+
+  protected def sparkConf = {
+    new SparkConf()
+      .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
+      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
+      .set(SQLConf.CODEGEN_FALLBACK.key, "false")
+  }
+
+  /**
+   * The [[TestSparkSession]] to use for all tests in this suite.
+   *
+   * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
+   * mode with the default test configurations.
+   */
+  private var _spark: TestSparkSession = null
+
+  /**
+   * The [[TestSparkSession]] to use for all tests in this suite.
+   */
+  protected implicit def spark: SparkSession = _spark
+
+  /**
+   * The [[TestSQLContext]] to use for all tests in this suite.
+   */
+  protected implicit def sqlContext: SQLContext = _spark.sqlContext
+
+  protected def createSparkSession: TestSparkSession = {
+    new TestSparkSession(sparkConf)
+  }
+
+  /**
+   * Initialize the [[TestSparkSession]].  Generally, this is just called from
+   * beforeAll; however, in tests using styles other than FunSuite, there is
+   * often code between the test group constructs and the actual tests that
+   * relies on this session.  It is purely a semantic
+   * difference, but semantically, it makes more sense to call
+   * 'initializeSession' between a 'describe' and an 'it' call than it does to
+   * call 'beforeAll'.
+   */
+  protected def initializeSession(): Unit = {
+    SparkSession.sqlListener.set(null)
+    if (_spark == null) {
+      _spark = createSparkSession
+    }
+  }
+
+  /**
+   * Make sure the [[TestSparkSession]] is initialized before any tests are run.
+   */
+  protected override def beforeAll(): Unit = {
+    initializeSession()
+
+    // Ensure we have initialized the context before calling parent code
+    super.beforeAll()
+  }
+
+  /**
+   * Stop the underlying [[org.apache.spark.SparkContext]], if any.
+   */
+  protected override def afterAll(): Unit = {
+    super.afterAll()
+    if (_spark != null) {
+      _spark.sessionState.catalog.reset()
+      _spark.stop()
+      _spark = null
+    }
+  }
+
+  protected override def beforeEach(): Unit = {
+    super.beforeEach()
+    DebugFilesystem.clearOpenStreams()
+  }
+
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    // Clear all persistent datasets after each test
+    spark.sharedState.cacheManager.clearCache()
+    // files can be closed from other threads, so wait a bit
+    // normally this doesn't take more than 1s
+    eventually(timeout(10.seconds)) {
+      DebugFilesystem.assertNoOpenStreams()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org