You are viewing a plain text version of this content. The link to the canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/04/07 17:38:02 UTC

[2/2] incubator-systemml git commit: [SYSTEMML-1347] Accept SparkSession in Java/Scala MLContext API

[SYSTEMML-1347] Accept SparkSession in Java/Scala MLContext API

Add an MLContext constructor that accepts a SparkSession.
In MLContext, store a SparkSession reference instead of a SparkContext.
Remove the unused monitoring parameter from MLContext.
Simplify MLContextUtil and MLContextConversionUtil.
Add a method for creating a SparkSession in AutomatedTestBase.
Update tests to use SparkSession.
Document the MLContext SparkSession constructor in the MLContext programming guide.

Closes #405.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/9c19b477
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/9c19b477
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/9c19b477

Branch: refs/heads/master
Commit: 9c19b4771caa96af4e959dda363d41e32818fb56
Parents: 9820f4c
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Fri Apr 7 10:35:55 2017 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Fri Apr 7 10:35:55 2017 -0700

----------------------------------------------------------------------
 docs/spark-mlcontext-programming-guide.md       | 14 +--
 .../apache/sysml/api/mlcontext/MLContext.java   | 54 +++++------
 .../api/mlcontext/MLContextConversionUtil.java  | 92 +++++++++---------
 .../sysml/api/mlcontext/MLContextUtil.java      | 45 +++++++--
 .../context/SparkExecutionContext.java          | 16 ++--
 .../test/integration/AutomatedTestBase.java     | 26 ++++++
 .../DataFrameMatrixConversionTest.java          | 43 +++++----
 .../DataFrameRowFrameConversionTest.java        | 49 +++++-----
 .../DataFrameVectorFrameConversionTest.java     | 42 +++++----
 .../mlcontext/DataFrameVectorScriptTest.java    | 62 ++++++-------
 .../functions/mlcontext/FrameTest.java          | 25 ++---
 .../functions/mlcontext/GNMFTest.java           | 21 ++---
 .../mlcontext/MLContextFrameTest.java           | 43 +++------
 .../mlcontext/MLContextMultipleScriptsTest.java | 16 ++--
 .../mlcontext/MLContextScratchCleanupTest.java  | 16 ++--
 .../integration/mlcontext/MLContextTest.java    | 98 +++++++-------------
 16 files changed, 327 insertions(+), 335 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/docs/spark-mlcontext-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/spark-mlcontext-programming-guide.md b/docs/spark-mlcontext-programming-guide.md
index c28eaf5..3b7bfc8 100644
--- a/docs/spark-mlcontext-programming-guide.md
+++ b/docs/spark-mlcontext-programming-guide.md
@@ -47,10 +47,10 @@ spark-shell --executor-memory 4G --driver-memory 4G --jars SystemML.jar
 
 ## Create MLContext
 
-All primary classes that a user interacts with are located in the `org.apache.sysml.api.mlcontext package`.
-For convenience, we can additionally add a static import of ScriptFactory to shorten the syntax for creating Script objects.
-An `MLContext` object can be created by passing its constructor a reference to the `SparkContext`. If successful, you
-should see a "`Welcome to Apache SystemML!`" message.
+All primary classes that a user interacts with are located in the `org.apache.sysml.api.mlcontext` package.
+For convenience, we can additionally add a static import of `ScriptFactory` to shorten the syntax for creating `Script` objects.
+An `MLContext` object can be created by passing its constructor a reference to the `SparkSession` (`spark`) or `SparkContext` (`sc`).
+If successful, you should see a "`Welcome to Apache SystemML!`" message.
 
 <div class="codetabs">
 
@@ -58,7 +58,7 @@ should see a "`Welcome to Apache SystemML!`" message.
 {% highlight scala %}
 import org.apache.sysml.api.mlcontext._
 import org.apache.sysml.api.mlcontext.ScriptFactory._
-val ml = new MLContext(sc)
+val ml = new MLContext(spark)
 {% endhighlight %}
 </div>
 
@@ -70,7 +70,7 @@ import org.apache.sysml.api.mlcontext._
 scala> import org.apache.sysml.api.mlcontext.ScriptFactory._
 import org.apache.sysml.api.mlcontext.ScriptFactory._
 
-scala> val ml = new MLContext(sc)
+scala> val ml = new MLContext(spark)
 
 Welcome to Apache SystemML!
 
@@ -1753,7 +1753,7 @@ Archiver-Version: Plexus Archiver
 Artifact-Id: systemml
 Build-Jdk: 1.8.0_60
 Build-Time: 2017-02-03 22:32:43 UTC
-Built-By: deroneriksson
+Built-By: sparkuser
 Created-By: Apache Maven 3.3.9
 Group-Id: org.apache.systemml
 Main-Class: org.apache.sysml.api.DMLScript

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
index cb98083..41df7fd 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java
@@ -29,6 +29,7 @@ import java.util.Set;
 import org.apache.log4j.Logger;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.MLContextProxy;
@@ -62,9 +63,9 @@ public class MLContext {
 	public static Logger log = Logger.getLogger(MLContext.class);
 
 	/**
-	 * SparkContext object.
+	 * SparkSession object.
 	 */
-	private SparkContext sc = null;
+	private SparkSession spark = null;
 
 	/**
 	 * Reference to the currently executing script.
@@ -164,6 +165,16 @@ public class MLContext {
 	}
 
 	/**
+	 * Create an MLContext based on a SparkSession for interaction with SystemML
+	 * on Spark.
+	 * 
+	 * @param spark SparkSession
+	 */
+	public MLContext(SparkSession spark) {
+		initMLContext(spark);
+	}
+
+	/**
 	 * Create an MLContext based on a SparkContext for interaction with SystemML
 	 * on Spark.
 	 *
@@ -171,7 +182,7 @@ public class MLContext {
 	 *            SparkContext
 	 */
 	public MLContext(SparkContext sparkContext) {
-		this(sparkContext, false);
+		initMLContext(SparkSession.builder().sparkContext(sparkContext).getOrCreate());
 	}
 
 	/**
@@ -182,38 +193,21 @@ public class MLContext {
 	 *            JavaSparkContext
 	 */
 	public MLContext(JavaSparkContext javaSparkContext) {
-		this(javaSparkContext.sc(), false);
-	}
-
-	/**
-	 * Create an MLContext based on a SparkContext for interaction with SystemML
-	 * on Spark, optionally monitor performance.
-	 *
-	 * @param sc
-	 *            SparkContext object.
-	 * @param monitorPerformance
-	 *            {@code true} if performance should be monitored, {@code false}
-	 *            otherwise
-	 */
-	public MLContext(SparkContext sc, boolean monitorPerformance) {
-		initMLContext(sc, monitorPerformance);
+		initMLContext(SparkSession.builder().sparkContext(javaSparkContext.sc()).getOrCreate());
 	}
 
 	/**
 	 * Initialize MLContext. Verify Spark version supported, set default
 	 * execution mode, set MLContextProxy, set default config, set compiler
-	 * config, and configure monitoring if needed.
+	 * config.
 	 *
 	 * @param sc
 	 *            SparkContext object.
-	 * @param monitorPerformance
-	 *            {@code true} if performance should be monitored, {@code false}
-	 *            otherwise
 	 */
-	private void initMLContext(SparkContext sc, boolean monitorPerformance) {
+	private void initMLContext(SparkSession spark) {
 
 		try {
-			MLContextUtil.verifySparkVersionSupported(sc);
+			MLContextUtil.verifySparkVersionSupported(spark);
 		} catch (MLContextException e) {
 			if (info() != null) {
 				log.warn("Apache Spark " + this.info().minimumRecommendedSparkVersion() + " or above is recommended for SystemML " + this.info().version());
@@ -231,7 +225,7 @@ public class MLContext {
 			System.out.println(MLContextUtil.welcomeMessage());
 		}
 
-		this.sc = sc;
+		this.spark = spark;
 		// by default, run in hybrid Spark mode for optimal performance
 		DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
 
@@ -329,12 +323,12 @@ public class MLContext {
 	}
 
 	/**
-	 * Obtain the SparkContext associated with this MLContext.
+	 * Obtain the SparkSession associated with this MLContext.
 	 *
-	 * @return the SparkContext associated with this MLContext.
+	 * @return the SparkSession associated with this MLContext.
 	 */
-	public SparkContext getSparkContext() {
-		return sc;
+	public SparkSession getSparkSession() {
+		return spark;
 	}
 
 	/**
@@ -641,7 +635,7 @@ public class MLContext {
 		scripts.clear();
 		scriptHistoryStrings.clear();
 		resetConfig();
-		sc = null;
+		spark = null;
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
index c496325..dc20108 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java
@@ -40,7 +40,6 @@ import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.parser.Expression.ValueType;
@@ -137,11 +136,7 @@ public class MLContextConversionUtil {
 		try {
 			InputStream is = url.openStream();
 			List<String> lines = IOUtils.readLines(is);
-			MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
-			SparkContext sparkContext = activeMLContext.getSparkContext();
-			@SuppressWarnings("resource")
-			JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext);
-			JavaRDD<String> javaRDD = javaSparkContext.parallelize(lines);
+			JavaRDD<String> javaRDD = jsc().parallelize(lines);
 			if ((matrixMetadata == null) || (matrixMetadata.getMatrixFormat() == MatrixFormat.CSV)) {
 				return javaRDDStringCSVToMatrixObject(variableName, javaRDD, matrixMetadata);
 			} else if (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV) {
@@ -370,8 +365,6 @@ public class MLContextConversionUtil {
 				frameMetadata = new FrameMetadata();
 			determineFrameFormatIfNeeded(dataFrame, frameMetadata);
 			boolean containsID = isDataFrameWithIDColumn(frameMetadata);
-			JavaSparkContext javaSparkContext = MLContextUtil
-					.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
 			MatrixCharacteristics mc = frameMetadata.asMatrixCharacteristics();
 			if( mc == null )
 				mc = new MatrixCharacteristics();
@@ -380,7 +373,7 @@ public class MLContextConversionUtil {
 			//TODO extend frame schema by column names (right now dropped)
 			Pair<String[], ValueType[]> ret = new Pair<String[], ValueType[]>(); 
 			JavaPairRDD<Long, FrameBlock> binaryBlock = FrameRDDConverterUtils
-				.dataFrameToBinaryBlock(javaSparkContext, dataFrame, mc, containsID, ret);
+				.dataFrameToBinaryBlock(jsc(), dataFrame, mc, containsID, ret);
 			frameMetadata.setFrameSchema(new FrameSchema(Arrays.asList(ret.getValue())));
 			frameMetadata.setMatrixCharacteristics(mc); //required due to meta data copy
 			
@@ -426,13 +419,10 @@ public class MLContextConversionUtil {
 				matrixMetadata.asMatrixCharacteristics() : new MatrixCharacteristics();
 		boolean containsID = isDataFrameWithIDColumn(matrixMetadata);
 		boolean isVector = isVectorBasedDataFrame(matrixMetadata);
-	
-		//get spark context
-		JavaSparkContext sc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
 
 		//convert data frame to binary block matrix
 		JavaPairRDD<MatrixIndexes,MatrixBlock> out = RDDConverterUtils
-				.dataFrameToBinaryBlock(sc, dataFrame, mc, containsID, isVector);
+				.dataFrameToBinaryBlock(jsc(), dataFrame, mc, containsID, isVector);
 		
 		//update determined matrix characteristics
 		if( matrixMetadata != null )
@@ -639,14 +629,12 @@ public class MLContextConversionUtil {
 				frameMetadata.asMatrixCharacteristics() : new MatrixCharacteristics();
 		JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction());
 
-		JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
-
 		FrameObject frameObject = new FrameObject(OptimizerUtils.getUniqueTempFileName(), 
 				new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), 
 				frameMetadata.getFrameSchema().getSchema().toArray(new ValueType[0]));
 		JavaPairRDD<Long, FrameBlock> rdd;
 		try {
-			rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc, javaPairRDDText, mc, 
+			rdd = FrameRDDConverterUtils.csvToBinaryBlock(jsc(), javaPairRDDText, mc,
 					frameObject.getSchema(), false, ",", false, -1);
 		} catch (DMLRuntimeException e) {
 			e.printStackTrace();
@@ -701,8 +689,6 @@ public class MLContextConversionUtil {
 
 		JavaPairRDD<LongWritable, Text> javaPairRDDText = javaPairRDD.mapToPair(new CopyTextInputFunction());
 
-		JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
-
 		FrameObject frameObject = new FrameObject(OptimizerUtils.getUniqueTempFileName(), 
 				new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo), 
 				frameMetadata.getFrameSchema().getSchema().toArray(new ValueType[0]));
@@ -711,7 +697,7 @@ public class MLContextConversionUtil {
 			ValueType[] lschema = null;
 			if (lschema == null)
 				lschema = UtilFunctions.nCopies((int) mc.getCols(), ValueType.STRING);
-			rdd = FrameRDDConverterUtils.textCellToBinaryBlock(jsc, javaPairRDDText, mc, lschema);
+			rdd = FrameRDDConverterUtils.textCellToBinaryBlock(jsc(), javaPairRDDText, mc, lschema);
 		} catch (DMLRuntimeException e) {
 			e.printStackTrace();
 			return null;
@@ -859,11 +845,7 @@ public class MLContextConversionUtil {
 	public static JavaRDD<String> matrixObjectToJavaRDDStringCSV(MatrixObject matrixObject) {
 		List<String> list = matrixObjectToListStringCSV(matrixObject);
 
-		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
-		SparkContext sc = activeMLContext.getSparkContext();
-		@SuppressWarnings("resource")
-		JavaSparkContext jsc = new JavaSparkContext(sc);
-		return jsc.parallelize(list);
+		return jsc().parallelize(list);
 	}
 
 	/**
@@ -877,8 +859,7 @@ public class MLContextConversionUtil {
 	public static JavaRDD<String> frameObjectToJavaRDDStringCSV(FrameObject frameObject, String delimiter) {
 		List<String> list = frameObjectToListStringCSV(frameObject, delimiter);
 
-		JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
-		return jsc.parallelize(list);
+		return jsc().parallelize(list);
 	}
 
 	/**
@@ -892,11 +873,7 @@ public class MLContextConversionUtil {
 	public static JavaRDD<String> matrixObjectToJavaRDDStringIJV(MatrixObject matrixObject) {
 		List<String> list = matrixObjectToListStringIJV(matrixObject);
 
-		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
-		SparkContext sc = activeMLContext.getSparkContext();
-		@SuppressWarnings("resource")
-		JavaSparkContext jsc = new JavaSparkContext(sc);
-		return jsc.parallelize(list);
+		return jsc().parallelize(list);
 	}
 
 	/**
@@ -909,8 +886,7 @@ public class MLContextConversionUtil {
 	public static JavaRDD<String> frameObjectToJavaRDDStringIJV(FrameObject frameObject) {
 		List<String> list = frameObjectToListStringIJV(frameObject);
 
-		JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
-		return jsc.parallelize(list);
+		return jsc().parallelize(list);
 	}
 
 	/**
@@ -934,10 +910,8 @@ public class MLContextConversionUtil {
 
 		List<String> list = matrixObjectToListStringIJV(matrixObject);
 
-		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
-		SparkContext sc = activeMLContext.getSparkContext();
 		ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
-		return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
+		return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag);
 	}
 
 	/**
@@ -961,9 +935,8 @@ public class MLContextConversionUtil {
 
 		List<String> list = frameObjectToListStringIJV(frameObject);
 
-		SparkContext sc = MLContextUtil.getSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
 		ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
-		return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
+		return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag);
 	}
 
 	/**
@@ -987,10 +960,8 @@ public class MLContextConversionUtil {
 
 		List<String> list = matrixObjectToListStringCSV(matrixObject);
 
-		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
-		SparkContext sc = activeMLContext.getSparkContext();
 		ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
-		return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
+		return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag);
 	}
 
 	/**
@@ -1015,9 +986,8 @@ public class MLContextConversionUtil {
 
 		List<String> list = frameObjectToListStringCSV(frameObject, delimiter);
 
-		SparkContext sc = MLContextUtil.getSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
 		ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
-		return sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag);
+		return sc().parallelize(JavaConversions.asScalaBuffer(list), sc().defaultParallelism(), tag);
 	}
 
 	/**
@@ -1247,10 +1217,7 @@ public class MLContextConversionUtil {
 					.getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo);
 			MatrixCharacteristics mc = matrixObject.getMatrixCharacteristics();
 
-			SparkContext sc = ((MLContext) MLContextProxy.getActiveMLContextForAPI()).getSparkContext();
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc).getOrCreate();
-			
-			return RDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryBlockMatrix, mc, isVectorDF);
+			return RDDConverterUtils.binaryBlockToDataFrame(spark(), binaryBlockMatrix, mc, isVectorDF);
 		} 
 		catch (DMLRuntimeException e) {
 			throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e);
@@ -1274,9 +1241,7 @@ public class MLContextConversionUtil {
 					.getRDDHandleForFrameObject(frameObject, InputInfo.BinaryBlockInputInfo);
 			MatrixCharacteristics mc = frameObject.getMatrixCharacteristics();
 
-			JavaSparkContext jsc = MLContextUtil.getJavaSparkContext((MLContext) MLContextProxy.getActiveMLContextForAPI());
-			SparkSession sparkSession = SparkSession.builder().sparkContext(jsc.sc()).getOrCreate();
-			return FrameRDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryBlockFrame, mc, frameObject.getSchema());
+			return FrameRDDConverterUtils.binaryBlockToDataFrame(spark(), binaryBlockFrame, mc, frameObject.getSchema());
 		} 
 		catch (DMLRuntimeException e) {
 			throw new MLContextException("DMLRuntimeException while converting frame object to DataFrame", e);
@@ -1348,4 +1313,31 @@ public class MLContextConversionUtil {
 			throw new MLContextException("DMLRuntimeException while converting frame object to 2D string array", e);
 		}
 	}
+
+	/**
+	 * Obtain a JavaSparkContext for the MLContext currently active in MLContextProxy.
+	 * Delegates to MLContextUtil.getJavaSparkContextFromProxy(), which wraps the context anew on each call.
+	 * @return the Java Spark Context
+	 */
+	public static JavaSparkContext jsc() {
+		return MLContextUtil.getJavaSparkContextFromProxy();
+	}
+
+	/**
+	 * Obtain the SparkContext of the MLContext currently active in MLContextProxy.
+	 * Delegates to MLContextUtil.getSparkContextFromProxy().
+	 * @return the Spark Context
+	 */
+	public static SparkContext sc() {
+		return MLContextUtil.getSparkContextFromProxy();
+	}
+
+	/**
+	 * Obtain the SparkSession held by the MLContext currently active in MLContextProxy.
+	 * Delegates to MLContextUtil.getSparkSessionFromProxy().
+	 * @return the Spark Session
+	 */
+	public static SparkSession spark() {
+		return MLContextUtil.getSparkSessionFromProxy();
+	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
index 4cd95d4..c4314bf 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java
@@ -47,10 +47,12 @@ import org.apache.spark.mllib.util.MLUtils;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
+import org.apache.sysml.api.MLContextProxy;
 import org.apache.sysml.conf.CompilerConfig;
 import org.apache.sysml.conf.CompilerConfig.ConfigType;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -170,12 +172,12 @@ public final class MLContextUtil {
 	 * Check that the Spark version is supported. If it isn't supported, throw
 	 * an MLContextException.
 	 * 
-	 * @param sc
-	 *            SparkContext
+	 * @param spark
+	 *            SparkSession
 	 * @throws MLContextException
 	 *             thrown if Spark version isn't supported
 	 */
-	public static void verifySparkVersionSupported(SparkContext sc) {
+	public static void verifySparkVersionSupported(SparkSession spark) {
 		String minimumRecommendedSparkVersion = null;
 		try {
 			// If this is being called using the SystemML jar file,
@@ -192,7 +194,7 @@ public final class MLContextUtil {
 				throw new MLContextException("Minimum recommended Spark version could not be determined from SystemML jar file manifest or pom.xml");
 			}
 		}
-		String sparkVersion = sc.version();
+		String sparkVersion = spark.version();
 		if (!MLContextUtil.isSparkVersionSupported(sparkVersion, minimumRecommendedSparkVersion)) {
 			throw new MLContextException(
 					"Spark " + sparkVersion + " or greater is recommended for this version of SystemML.");
@@ -1027,7 +1029,7 @@ public final class MLContextUtil {
 	 * @return the Spark Context
 	 */
 	public static SparkContext getSparkContext(MLContext mlContext) {
-		return mlContext.getSparkContext();
+		return mlContext.getSparkSession().sparkContext();
 	}
 
 	/**
@@ -1038,7 +1040,38 @@ public final class MLContextUtil {
 	 * @return the Java Spark Context
 	 */
 	public static JavaSparkContext getJavaSparkContext(MLContext mlContext) {
-		return new JavaSparkContext(mlContext.getSparkContext());
+		return new JavaSparkContext(mlContext.getSparkSession().sparkContext());
+	}
+
+	/**
+	 * Obtain the Spark Context from the MLContextProxy
+	 * 
+	 * @return the Spark Context
+	 */
+	public static SparkContext getSparkContextFromProxy() {
+		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
+		SparkContext sc = getSparkContext(activeMLContext);
+		return sc;
+	}
+
+	/**
+	 * Obtain the Java Spark Context from the MLContextProxy
+	 * 
+	 * @return the Java Spark Context
+	 */
+	public static JavaSparkContext getJavaSparkContextFromProxy() {
+		MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI();
+		JavaSparkContext jsc = getJavaSparkContext(activeMLContext);
+		return jsc;
+	}
+
+	/**
+	 * Obtain the Spark Session from the MLContextProxy
+	 * 
+	 * @return the Spark Session
+	 */
+	public static SparkSession getSparkSessionFromProxy() {
+		return ((MLContext) MLContextProxy.getActiveMLContextForAPI()).getSparkSession();
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index c2e3dd0..92946ff 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -39,6 +39,7 @@ import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.LongAccumulator;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.MLContextProxy;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
 import org.apache.sysml.lops.Checkpoint;
@@ -85,10 +86,11 @@ public class SparkExecutionContext extends ExecutionContext
 	private static final boolean LDEBUG = false; //local debug flag
 	
 	//internal configurations 
-	private static final boolean LAZY_SPARKCTX_CREATION = true;
-	private static final boolean ASYNCHRONOUS_VAR_DESTROY = true;
-	private static final boolean FAIR_SCHEDULER_MODE = true;
-	
+	private static boolean LAZY_SPARKCTX_CREATION = true;
+	private static boolean ASYNCHRONOUS_VAR_DESTROY = true;
+
+	public static boolean FAIR_SCHEDULER_MODE = true;
+
 	//executor memory and relative fractions as obtained from the spark configuration
 	private static SparkClusterConfig _sconf = null;
 	
@@ -198,7 +200,7 @@ public class SparkExecutionContext extends ExecutionContext
 				_spctx = new JavaSparkContext(mlCtx.getSparkContext());
 			} else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
 				org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
-				_spctx = new JavaSparkContext(mlCtx.getSparkContext());
+				_spctx = MLContextUtil.getJavaSparkContext(mlCtx);
 			}
 		}
 		else 
@@ -267,12 +269,12 @@ public class SparkExecutionContext extends ExecutionContext
 		
 		return conf;
 	}
-	
+
 	/**
 	 * Spark instructions should call this for all matrix inputs except broadcast
 	 * variables.
 	 * 
-	 * @param varname varible name
+	 * @param varname variable name
 	 * @return JavaPairRDD of MatrixIndexes-MatrixBlocks
 	 * @throws DMLRuntimeException if DMLRuntimeException occurs
 	 */

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index 318a1c8..f3ede65 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -35,6 +35,8 @@ import java.util.HashMap;
 import org.apache.sysml.lops.Lop;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.SparkSession.Builder;
 import org.apache.wink.json4j.JSONObject;
 import org.junit.After;
 import org.junit.Assert;
@@ -1788,4 +1790,28 @@ public abstract class AutomatedTestBase
 				return true;
 		return false;		
 	}
+
+	/**
+	 * Create, or reuse via getOrCreate(), a SparkSession with SystemML-preferred settings:
+	 * spark.driver.maxResultSize=0, spark.locality.wait=5s, and FAIR scheduling when enabled.
+	 * @param appName the application name, or null to leave it unset
+	 * @param master the master URL (e.g., "local"), or null to leave it unset
+	 * @return the resulting Spark Session
+	 */
+	public static SparkSession createSystemMLSparkSession(String appName, String master) {
+		Builder builder = SparkSession.builder();
+		if (appName != null) {
+			builder.appName(appName);
+		}
+		if (master != null) {
+			builder.master(master);
+		}
+		builder.config("spark.driver.maxResultSize", "0");
+		if (SparkExecutionContext.FAIR_SCHEDULER_MODE) {
+			builder.config("spark.scheduler.mode", "FAIR");
+		}
+		builder.config("spark.locality.wait", "5s");
+		SparkSession spark = builder.getOrCreate();
+		return spark;
+	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java
index bf5d33d..a6b6811 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameMatrixConversionTest.java
@@ -27,7 +27,6 @@ import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -37,6 +36,8 @@ import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 
@@ -55,7 +56,15 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase
 	private final static double sparsity2 = 0.1;
 	private final static double eps=0.0000000001;
 
-	 
+	private static SparkSession spark;
+	private static JavaSparkContext sc;
+
+	@BeforeClass
+	public static void setUpClass() {
+		spark = createSystemMLSparkSession("DataFrameMatrixConversionTest", "local"); // shared by all tests in this class; stopped in tearDownClass
+		sc = new JavaSparkContext(spark.sparkContext()); // Java RDD API view of the session's SparkContext
+	}
+
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -160,20 +169,11 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase
 	public void testVectorConversionWideSparseUnknown() {
 		testDataFrameConversion(true, cols3, false, true);
 	}
-	
-	/**
-	 * 
-	 * @param vector
-	 * @param singleColBlock
-	 * @param dense
-	 * @param unknownDims
-	 */
+
 	private void testDataFrameConversion(boolean vector, int cols, boolean dense, boolean unknownDims) {
 		boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; 
 		RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform;
 
-		SparkExecutionContext sec = null;
-		
 		try
 		{
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
@@ -187,17 +187,12 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase
 			int blksz = ConfigurationManager.getBlocksize();
 			MatrixCharacteristics mc1 = new MatrixCharacteristics(rows, cols, blksz, blksz, mbA.getNonZeros());
 			MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1);
-			
-			//setup spark context
-			sec = (SparkExecutionContext) ExecutionContextFactory.createContext();		
-			JavaSparkContext sc = sec.getSparkContext();
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-			
+
 			//get binary block input rdd
 			JavaPairRDD<MatrixIndexes,MatrixBlock> in = SparkExecutionContext.toMatrixJavaPairRDD(sc, mbA, blksz, blksz);
 			
 			//matrix - dataframe - matrix conversion
-			Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, vector);
+			Dataset<Row> df = RDDConverterUtils.binaryBlockToDataFrame(spark, in, mc1, vector);
 			df = ( rows==rows3 ) ? df.repartition(rows) : df;
 			JavaPairRDD<MatrixIndexes,MatrixBlock> out = RDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, true, vector);
 			
@@ -212,9 +207,17 @@ public class DataFrameMatrixConversionTest extends AutomatedTestBase
 			throw new RuntimeException(ex);
 		}
 		finally {
-			sec.close();
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
 			DMLScript.rtplatform = oldPlatform;
 		}
 	}
+
+	@AfterClass
+	public static void tearDownClass() {
+		// stop underlying spark context to allow single jvm tests (otherwise the
+		// next test that tries to create a SparkContext would fail)
+		spark.stop();
+		sc = null;
+		spark = null;
+	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java
index 09628e5..452c1e1 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameRowFrameConversionTest.java
@@ -28,7 +28,6 @@ import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.Expression.ValueType;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -39,6 +38,8 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 
@@ -55,7 +56,20 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase
 	private final static double sparsity2 = 0.1;
 	private final static double eps=0.0000000001;
 
-	 
+	private static SparkSession spark;
+	private static JavaSparkContext sc;
+
+	@BeforeClass
+	public static void setUpClass() {
+		spark = SparkSession.builder()
+		.appName("DataFrameRowFrameConversionTest")
+		.master("local")
+		.config("spark.memory.offHeap.enabled", "false")
+		.config("spark.sql.codegen.wholeStage", "false")
+		.getOrCreate();
+		sc = new JavaSparkContext(spark.sparkContext());
+	}
+
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -182,20 +196,11 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase
 	public void testRowLongConversionMultiSparseUnknown() {
 		testDataFrameConversion(ValueType.INT, false, false, true);
 	}
-	
-	/**
-	 * 
-	 * @param vector
-	 * @param singleColBlock
-	 * @param dense
-	 * @param unknownDims
-	 */
+
 	private void testDataFrameConversion(ValueType vt, boolean singleColBlock, boolean dense, boolean unknownDims) {
 		boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; 
 		RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform;
 
-		SparkExecutionContext sec = null;
-		
 		try
 		{
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
@@ -212,20 +217,12 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase
 			MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros());
 			MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1);
 			ValueType[] schema = UtilFunctions.nCopies(cols, vt);
-			
-			//setup spark context
-			sec = (SparkExecutionContext) ExecutionContextFactory.createContext();		
-			JavaSparkContext sc = sec.getSparkContext();
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-
-			sc.getConf().set("spark.memory.offHeap.enabled", "false");
-			sparkSession.conf().set("spark.sql.codegen.wholeStage", "false");
 
 			//get binary block input rdd
 			JavaPairRDD<Long,FrameBlock> in = SparkExecutionContext.toFrameJavaPairRDD(sc, fbA);
 			
 			//frame - dataframe - frame conversion
-			Dataset<Row> df = FrameRDDConverterUtils.binaryBlockToDataFrame(sparkSession, in, mc1, schema);
+			Dataset<Row> df = FrameRDDConverterUtils.binaryBlockToDataFrame(spark, in, mc1, schema);
 			JavaPairRDD<Long,FrameBlock> out = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, true);
 			
 			//get output frame block
@@ -240,9 +237,17 @@ public class DataFrameRowFrameConversionTest extends AutomatedTestBase
 			throw new RuntimeException(ex);
 		}
 		finally {
-			sec.close();
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
 			DMLScript.rtplatform = oldPlatform;
 		}
 	}
+
+	@AfterClass
+	public static void tearDownClass() {
+		// stop underlying spark context to allow single jvm tests (otherwise the
+		// next test that tries to create a SparkContext would fail)
+		spark.stop();
+		sc = null;
+		spark = null;
+	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java
index 4a73376..e68eee9 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorFrameConversionTest.java
@@ -40,7 +40,6 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
@@ -52,6 +51,8 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 
@@ -73,6 +74,15 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase
 	private final static double sparsity2 = 0.1;
 	private final static double eps=0.0000000001;
 
+	private static SparkSession spark;
+	private static JavaSparkContext sc;
+
+	@BeforeClass
+	public static void setUpClass() {
+		spark = createSystemMLSparkSession("DataFrameVectorFrameConversionTest", "local");
+		sc = new JavaSparkContext(spark.sparkContext());
+	}
+
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -237,20 +247,11 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase
 	public void testVectorMixed2ConversionSparse() {
 		testDataFrameConversion(schemaMixed2, false, true, false);
 	}
-	
-	/**
-	 * 
-	 * @param vector
-	 * @param singleColBlock
-	 * @param dense
-	 * @param unknownDims
-	 */
+
 	private void testDataFrameConversion(ValueType[] schema, boolean containsID, boolean dense, boolean unknownDims) {
 		boolean oldConfig = DMLScript.USE_LOCAL_SPARK_CONFIG; 
 		RUNTIME_PLATFORM oldPlatform = DMLScript.rtplatform;
 
-		SparkExecutionContext sec = null;
-		
 		try
 		{
 			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
@@ -264,14 +265,9 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase
 			int blksz = ConfigurationManager.getBlocksize();
 			MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros());
 			MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1);
-			
-			//setup spark context
-			sec = (SparkExecutionContext) ExecutionContextFactory.createContext();		
-			JavaSparkContext sc = sec.getSparkContext();
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-			
+
 			//create input data frame
-			Dataset<Row> df = createDataFrame(sparkSession, mbA, containsID, schema);
+			Dataset<Row> df = createDataFrame(spark, mbA, containsID, schema);
 			
 			//dataframe - frame conversion
 			JavaPairRDD<Long,FrameBlock> out = FrameRDDConverterUtils.dataFrameToBinaryBlock(sc, df, mc2, containsID);
@@ -289,7 +285,6 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase
 			throw new RuntimeException(ex);
 		}
 		finally {
-			sec.close();
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
 			DMLScript.rtplatform = oldPlatform;
 		}
@@ -346,4 +341,13 @@ public class DataFrameVectorFrameConversionTest extends AutomatedTestBase
 		JavaRDD<Row> rowRDD = sc.parallelize(list);
 		return sparkSession.createDataFrame(rowRDD, dfSchema);
 	}
+
+	@AfterClass
+	public static void tearDownClass() {
+		// stop underlying spark context to allow single jvm tests (otherwise the
+		// next test that tries to create a SparkContext would fail)
+		spark.stop();
+		sc = null;
+		spark = null;
+	}
 }
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
index 92677b8..0f3d3b2 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
@@ -24,7 +24,6 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
 import java.util.ArrayList;
 import java.util.List;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.ml.linalg.DenseVector;
@@ -45,7 +44,6 @@ import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -54,6 +52,8 @@ import org.apache.sysml.runtime.util.UtilFunctions;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
 import org.apache.sysml.test.utils.TestUtils;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
 import org.junit.Test;
 
 
@@ -75,6 +75,16 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 	private final static double sparsity2 = 0.1;
 	private final static double eps=0.0000000001;
 
+	private static SparkSession spark;
+	private static MLContext ml;
+
+	@BeforeClass
+	public static void setUpClass() {
+		spark = createSystemMLSparkSession("DataFrameVectorScriptTest", "local");
+		ml = new MLContext(spark);
+		ml.setExplain(true);
+	}
+
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -239,21 +249,10 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 	public void testVectorMixed2ConversionSparse() {
 		testDataFrameScriptInput(schemaMixed2, false, true, false);
 	}
-	
-	/**
-	 * 
-	 * @param schema
-	 * @param containsID
-	 * @param dense
-	 * @param unknownDims
-	 */
+
 	private void testDataFrameScriptInput(ValueType[] schema, boolean containsID, boolean dense, boolean unknownDims) {
 		
 		//TODO fix inconsistency ml context vs jmlc register Xf
-		
-		JavaSparkContext sc = null;
-		MLContext ml = null;
-		
 		try
 		{
 			//generate input data and setup metadata
@@ -264,25 +263,15 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 			int blksz = ConfigurationManager.getBlocksize();
 			MatrixCharacteristics mc1 = new MatrixCharacteristics(rows1, cols, blksz, blksz, mbA.getNonZeros());
 			MatrixCharacteristics mc2 = unknownDims ? new MatrixCharacteristics() : new MatrixCharacteristics(mc1);
-			
-			//setup spark context
-			SparkConf conf = SparkExecutionContext.createSystemMLSparkConf()
-					.setAppName("MLContextFrameTest").setMaster("local");
-			sc = new JavaSparkContext(conf);
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-			
+
 			//create input data frame
-			Dataset<Row> df = createDataFrame(sparkSession, mbA, containsID, schema);
+			Dataset<Row> df = createDataFrame(spark, mbA, containsID, schema);
 
 			// Create full frame metadata, and empty frame metadata
 			FrameMetadata meta = new FrameMetadata(containsID ? FrameFormat.DF_WITH_INDEX :
 				FrameFormat.DF, mc2.getRows(), mc2.getCols());
 			FrameMetadata metaEmpty = new FrameMetadata();
 
-			//create mlcontext
-			ml = new MLContext(sc);
-			ml.setExplain(true);
-			
 			//run scripts and obtain result
 			Script script1 = dml(
 					"Xm = as.matrix(Xf);")
@@ -305,15 +294,6 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 			ex.printStackTrace();
 			throw new RuntimeException(ex);
 		}
-		finally {
-			// stop spark context to allow single jvm tests (otherwise the
-			// next test that tries to create a SparkContext would fail)
-			if( sc != null )
-				sc.stop();
-			// clear status mlcontext and spark exec context
-			if( ml != null )
-				ml.close();
-		}
 	}
 
 	@SuppressWarnings("resource")
@@ -367,4 +347,16 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 		JavaRDD<Row> rowRDD = sc.parallelize(list);
 		return sparkSession.createDataFrame(rowRDD, dfSchema);
 	}
+
+	@AfterClass
+	public static void tearDownClass() {
+		// stop underlying spark context to allow single jvm tests (otherwise the
+		// next test that tries to create a SparkContext would fail)
+		spark.stop();
+		spark = null;
+
+		// clear status mlcontext and spark exec context
+		ml.close();
+		ml = null;
+	}
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
index d485c48..c93968c 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
@@ -27,7 +27,6 @@ import java.util.HashMap;
 import java.util.List;
 
 import org.apache.hadoop.io.LongWritable;
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -42,6 +41,7 @@ import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
 import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.api.mlcontext.ScriptFactory;
@@ -49,7 +49,6 @@ import org.apache.sysml.parser.DataExpression;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils.LongFrameToLongWritableFrameFunction;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -99,18 +98,15 @@ public class FrameTest extends AutomatedTestBase
 		schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge);
 	}
 
-	private static SparkConf conf;
+	private static SparkSession spark;
 	private static JavaSparkContext sc;
 	private static MLContext ml;
 
 	@BeforeClass
 	public static void setUpClass() {
-		if (conf == null)
-			conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("FrameTest").setMaster("local");
-		if (sc == null)
-			sc = new JavaSparkContext(conf);
-		ml = new MLContext(sc);
+		spark = createSystemMLSparkSession("FrameTest", "local");
+		ml = new MLContext(spark);
+		sc = MLContextUtil.getJavaSparkContext(ml);
 	}
 
 	@Override
@@ -237,16 +233,15 @@ public class FrameTest extends AutomatedTestBase
 		if(bFromDataFrame)
 		{
 			//Create DataFrame for input A 
-			SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
 			StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schema, false);
 
 			JavaRDD<Row> rowRDDA = FrameRDDConverterUtils.csvToRowRDD(sc, input("A"), DataExpression.DEFAULT_DELIM_DELIMITER, schema);
-			dfA = sparkSession.createDataFrame(rowRDDA, dfSchemaA);
+			dfA = spark.createDataFrame(rowRDDA, dfSchemaA);
 			
 			//Create DataFrame for input B 
 			StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
 			JavaRDD<Row> rowRDDB = FrameRDDConverterUtils.csvToRowRDD(sc, input("B"), DataExpression.DEFAULT_DELIM_DELIMITER, schemaB);
-			dfB = sparkSession.createDataFrame(rowRDDB, dfSchemaB);
+			dfB = spark.createDataFrame(rowRDDB, dfSchemaB);
 		}
 
 		try 
@@ -386,11 +381,11 @@ public class FrameTest extends AutomatedTestBase
 
 	@AfterClass
 	public static void tearDownClass() {
-		// stop spark context to allow single jvm tests (otherwise the
+		// stop underlying spark context to allow single jvm tests (otherwise the
 		// next test that tries to create a SparkContext would fail)
-		sc.stop();
+		spark.stop();
 		sc = null;
-		conf = null;
+		spark = null;
 
 		// clear status mlcontext and spark exec context
 		ml.close();

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index eeeb925..f9f5fbd 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -26,7 +26,6 @@ import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
@@ -34,10 +33,12 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.rdd.RDD;
+import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -47,7 +48,6 @@ import org.apache.sysml.api.mlcontext.ScriptFactory;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -76,7 +76,7 @@ public class GNMFTest extends AutomatedTestBase
 	int numRegisteredInputs;
 	int numRegisteredOutputs;
 
-	private static SparkConf conf;
+	private static SparkSession spark;
 	private static JavaSparkContext sc;
 	private static MLContext ml;
 
@@ -87,12 +87,9 @@ public class GNMFTest extends AutomatedTestBase
 
 	@BeforeClass
 	public static void setUpClass() {
-		if (conf == null)
-			conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("GNMFTest").setMaster("local");
-		if (sc == null)
-			sc = new JavaSparkContext(conf);
-		ml = new MLContext(sc);
+		spark = createSystemMLSparkSession("GNMFTest", "local");
+		ml = new MLContext(spark);
+		sc = MLContextUtil.getJavaSparkContext(ml);
 	}
 
 	@Parameters
@@ -267,11 +264,11 @@ public class GNMFTest extends AutomatedTestBase
 
 	@AfterClass
 	public static void tearDownClass() {
-		// stop spark context to allow single jvm tests (otherwise the
+		// stop underlying spark context to allow single jvm tests (otherwise the
 		// next test that tries to create a SparkContext would fail)
-		sc.stop();
+		spark.stop();
 		sc = null;
-		conf = null;
+		spark = null;
 
 		// clear status mlcontext and spark exec context
 		ml.close();

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
index 6dd74d3..bab719e 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
@@ -27,7 +27,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.rdd.RDD;
@@ -43,12 +42,12 @@ import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
 import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
 import org.apache.sysml.api.mlcontext.MatrixMetadata;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.parser.Expression.ValueType;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.test.integration.AutomatedTestBase;
@@ -73,19 +72,16 @@ public class MLContextFrameTest extends AutomatedTestBase {
 		ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME
 	};
 
-	private static SparkConf conf;
+	private static SparkSession spark;
 	private static JavaSparkContext sc;
 	private static MLContext ml;
 	private static String CSV_DELIM = ",";
 
 	@BeforeClass
 	public static void setUpClass() {
-		if (conf == null)
-			conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("MLContextFrameTest").setMaster("local");
-		if (sc == null)
-			sc = new JavaSparkContext(conf);
-		ml = new MLContext(sc);
+		spark = createSystemMLSparkSession("MLContextFrameTest", "local");
+		ml = new MLContext(spark);
+		sc = MLContextUtil.getJavaSparkContext(ml);
 		ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS);
 	}
 
@@ -238,11 +234,10 @@ public class MLContextFrameTest extends AutomatedTestBase {
 				JavaRDD<Row> javaRddRowB = FrameRDDConverterUtils.csvToRowRDD(sc, javaRDDB, CSV_DELIM, schemaB);
 
 				// Create DataFrame
-				SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
 				StructType dfSchemaA = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaA, false);
-				Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, dfSchemaA);
+				Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, dfSchemaA);
 				StructType dfSchemaB = FrameRDDConverterUtils.convertFrameSchemaToDFSchema(schemaB, false);
-				Dataset<Row> dataFrameB = sparkSession.createDataFrame(javaRddRowB, dfSchemaB);
+				Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, dfSchemaB);
 				if (script_type == SCRIPT_TYPE.DML)
 					script = dml("A[2:3,2:4]=B;C=A[2:3,2:3]").in("A", dataFrameA, fmA).in("B", dataFrameB, fmB).out("A")
 							.out("C");
@@ -492,18 +487,16 @@ public class MLContextFrameTest extends AutomatedTestBase {
 		JavaRDD<Row> javaRddRowA = FrameRDDConverterUtils.csvToRowRDD(sc, javaRddStringA, CSV_DELIM, schema);
 		JavaRDD<Row> javaRddRowB = javaRddStringB.map(new CommaSeparatedValueStringToDoubleArrayRow());
 
-		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-
 		List<StructField> fieldsA = new ArrayList<StructField>();
 		fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType, true));
 		fieldsA.add(DataTypes.createStructField("2", DataTypes.DoubleType, true));
 		StructType schemaA = DataTypes.createStructType(fieldsA);
-		Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA);
+		Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA);
 
 		List<StructField> fieldsB = new ArrayList<StructField>();
 		fieldsB.add(DataTypes.createStructField("1", DataTypes.DoubleType, true));
 		StructType schemaB = DataTypes.createStructType(fieldsB);
-		Dataset<Row> dataFrameB = sparkSession.createDataFrame(javaRddRowB, schemaB);
+		Dataset<Row> dataFrameB = spark.createDataFrame(javaRddRowB, schemaB);
 
 		String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: true ,recode: [ 1, 2 ]}\");\n"
 				+ "C = tA %*% B;\n" + "M = s * C;";
@@ -529,14 +522,12 @@ public class MLContextFrameTest extends AutomatedTestBase {
 
 		JavaRDD<Row> javaRddRowA = sc. parallelize( Arrays.asList(rowsA)); 
 
-		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-
 		List<StructField> fieldsA = new ArrayList<StructField>();
 		fieldsA.add(DataTypes.createStructField("myID", DataTypes.StringType, true));
 		fieldsA.add(DataTypes.createStructField("FeatureName", DataTypes.StringType, true));
 		fieldsA.add(DataTypes.createStructField("FeatureValue", DataTypes.IntegerType, true));
 		StructType schemaA = DataTypes.createStructType(fieldsA);
-		Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA);
+		Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA);
 
 		String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ myID, FeatureName ]}\");";
 
@@ -571,14 +562,12 @@ public class MLContextFrameTest extends AutomatedTestBase {
 
 		JavaRDD<Row> javaRddRowA = sc. parallelize( Arrays.asList(rowsA)); 
 
-		SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-
 		List<StructField> fieldsA = new ArrayList<StructField>();
 		fieldsA.add(DataTypes.createStructField("featureName", DataTypes.StringType, true));
 		fieldsA.add(DataTypes.createStructField("featureValue", DataTypes.IntegerType, true));
 		fieldsA.add(DataTypes.createStructField("id", DataTypes.StringType, true));
 		StructType schemaA = DataTypes.createStructType(fieldsA);
-		Dataset<Row> dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA);
+		Dataset<Row> dataFrameA = spark.createDataFrame(javaRddRowA, schemaA);
 
 		String dmlString = "[tA, tAM] = transformencode (target = A, spec = \"{ids: false ,recode: [ featureName, id ]}\");";
 
@@ -621,15 +610,13 @@ public class MLContextFrameTest extends AutomatedTestBase {
 	// JavaRDD<Row> javaRddRowA = javaRddStringA.map(new
 	// CommaSeparatedValueStringToRow());
 	//
-	// SparkSession sparkSession = SparkSession.builder().sparkContext(sc.sc()).getOrCreate();
-	//
 	// List<StructField> fieldsA = new ArrayList<StructField>();
 	// fieldsA.add(DataTypes.createStructField("1", DataTypes.StringType,
 	// true));
 	// fieldsA.add(DataTypes.createStructField("2", DataTypes.StringType,
 	// true));
 	// StructType schemaA = DataTypes.createStructType(fieldsA);
-	// DataFrame dataFrameA = sparkSession.createDataFrame(javaRddRowA, schemaA);
+	// DataFrame dataFrameA = spark.createDataFrame(javaRddRowA, schemaA);
 	//
 	// String dmlString = "[tA, tAM] = transformencode (target = A, spec =
 	// \"{ids: true ,recode: [ 1, 2 ]}\");\n";
@@ -664,11 +651,11 @@ public class MLContextFrameTest extends AutomatedTestBase {
 
 	@AfterClass
 	public static void tearDownClass() {
-		// stop spark context to allow single jvm tests (otherwise the
+		// stop underlying spark context to allow single jvm tests (otherwise the
 		// next test that tries to create a SparkContext would fail)
-		sc.stop();
+		spark.stop();
 		sc = null;
-		conf = null;
+		spark = null;
 
 		// clear status mlcontext and spark exec context
 		ml.close();

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
index de46c2a..c418a6f 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
@@ -23,14 +23,12 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
 
 import java.io.File;
 
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.After;
@@ -92,12 +90,10 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase
 		DMLScript.rtplatform = platform;
 		
 		//create mlcontext
-		SparkConf conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("MLContextFrameTest").setMaster("local");
-		JavaSparkContext sc = new JavaSparkContext(conf);
-		MLContext ml = new MLContext(sc);
+		SparkSession spark = createSystemMLSparkSession("MLContextMultipleScriptsTest", "local");
+		MLContext ml = new MLContext(spark);
 		ml.setExplain(true);
-		
+
 		String dml1 = baseDirectory + File.separator + "MultiScript1.dml";
 		String dml2 = baseDirectory + File.separator + (wRead?"MultiScript2b.dml":"MultiScript2.dml");
 		String dml3 = baseDirectory + File.separator + (wRead?"MultiScript3b.dml":"MultiScript3.dml");
@@ -119,9 +115,9 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase
 		finally {
 			DMLScript.rtplatform = oldplatform;
 			
-			// stop spark context to allow single jvm tests (otherwise the
+			// stop underlying spark context to allow single jvm tests (otherwise the
 			// next test that tries to create a SparkContext would fail)
-			sc.stop();
+			spark.stop();
 			// clear status mlcontext and spark exec context
 			ml.close();
 		}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/9c19b477/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
index c9a3dbc..6391919 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
@@ -23,14 +23,12 @@ import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
 
 import java.io.File;
 
-import org.apache.spark.SparkConf;
-import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
 import org.junit.After;
@@ -92,12 +90,10 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase
 		DMLScript.rtplatform = platform;
 		
 		//create mlcontext
-		SparkConf conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("MLContextFrameTest").setMaster("local");
-		JavaSparkContext sc = new JavaSparkContext(conf);
-		MLContext ml = new MLContext(sc);
+		SparkSession spark = createSystemMLSparkSession("MLContextScratchCleanupTest", "local");
+		MLContext ml = new MLContext(spark);
 		ml.setExplain(true);
-		
+
 		String dml1 = baseDirectory + File.separator + "ScratchCleanup1.dml";
 		String dml2 = baseDirectory + File.separator + (wRead?"ScratchCleanup2b.dml":"ScratchCleanup2.dml");
 		
@@ -120,9 +116,9 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase
 		finally {
 			DMLScript.rtplatform = oldplatform;
 			
-			// stop spark context to allow single jvm tests (otherwise the
+			// stop underlying spark context to allow single jvm tests (otherwise the
 			// next test that tries to create a SparkContext would fail)
-			sc.stop();
+			spark.stop();
 			// clear status mlcontext and spark exec context
 			ml.close();
 		}