You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by de...@apache.org on 2017/07/19 18:31:15 UTC

systemml git commit: [SYSTEMML-1777] MLContextTestBase class for MLContext testing

Repository: systemml
Updated Branches:
  refs/heads/master 0ae2b4f77 -> ec38b3790


[SYSTEMML-1777] MLContextTestBase class for MLContext testing

Create abstract MLContextTestBase class that contains setup and shutdown
code for MLContext tests. This removes boilerplate code from MLContext
test classes that extend MLContextTestBase.

Closes #580.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/ec38b379
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/ec38b379
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/ec38b379

Branch: refs/heads/master
Commit: ec38b3790f11792d3337c35439954d422d6eb60b
Parents: 0ae2b4f
Author: Deron Eriksson <de...@apache.org>
Authored: Wed Jul 19 11:13:51 2017 -0700
Committer: Deron Eriksson <de...@apache.org>
Committed: Wed Jul 19 11:13:51 2017 -0700

----------------------------------------------------------------------
 .../mlcontext/DataFrameVectorScriptTest.java    |  29 +---
 .../functions/mlcontext/FrameTest.java          |  40 +-----
 .../functions/mlcontext/GNMFTest.java           |  40 +-----
 .../mlcontext/MLContextFrameTest.java           |  41 +-----
 .../mlcontext/MLContextMultipleScriptsTest.java |   4 -
 .../mlcontext/MLContextOutputBlocksizeTest.java |  51 +------
 .../mlcontext/MLContextParforDatasetTest.java   |  52 +------
 .../mlcontext/MLContextScratchCleanupTest.java  |   4 -
 .../integration/mlcontext/MLContextTest.java    | 143 +++----------------
 .../mlcontext/MLContextTestBase.java            |  89 ++++++++++++
 .../test/integration/scripts/nn/NNTest.java     |  46 +-----
 11 files changed, 123 insertions(+), 416 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
index 65aee8e..55b8371 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/DataFrameVectorScriptTest.java
@@ -38,7 +38,6 @@ import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.conf.ConfigurationManager;
@@ -49,15 +48,13 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.runtime.util.UtilFunctions;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class DataFrameVectorScriptTest extends AutomatedTestBase
+public class DataFrameVectorScriptTest extends MLContextTestBase
 {
 	private final static String TEST_DIR = "functions/mlcontext/";
 	private final static String TEST_NAME = "DataFrameConversion";
@@ -75,16 +72,6 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 	private final static double sparsity2 = 0.1;
 	private final static double eps=0.0000000001;
 
-	private static SparkSession spark;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		spark = createSystemMLSparkSession("DataFrameVectorScriptTest", "local");
-		ml = new MLContext(spark);
-		ml.setExplain(true);
-	}
-
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"A", "B"}));
@@ -343,16 +330,4 @@ public class DataFrameVectorScriptTest extends AutomatedTestBase
 		JavaRDD<Row> rowRDD = sc.parallelize(list);
 		return sparkSession.createDataFrame(rowRDD, dfSchema);
 	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		spark = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
index c93968c..382f433 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/FrameTest.java
@@ -29,10 +29,8 @@ import java.util.List;
 import org.apache.hadoop.io.LongWritable;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
@@ -40,8 +38,6 @@ import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.api.mlcontext.ScriptFactory;
@@ -57,17 +53,14 @@ import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.runtime.util.UtilFunctions;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.TestConfiguration;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class FrameTest extends AutomatedTestBase 
+public class FrameTest extends MLContextTestBase
 {
 	private final static String TEST_DIR = "functions/frame/";
 	private final static String TEST_NAME = "FrameGeneral";
@@ -98,17 +91,6 @@ public class FrameTest extends AutomatedTestBase
 		schemaMixedLarge = (ValueType[]) schemaMixedLargeList.toArray(schemaMixedLarge);
 	}
 
-	private static SparkSession spark;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		spark = createSystemMLSparkSession("FrameTest", "local");
-		ml = new MLContext(spark);
-		sc = MLContextUtil.getJavaSparkContext(ml);
-	}
-
 	@Override
 	public void setUp() {
 		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, 
@@ -373,22 +355,4 @@ public class FrameTest extends AutomatedTestBase
 							", not same as the R value " + val2);
 			}
 	}
-
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		sc = null;
-		spark = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index 76deec5..44f1f15 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -28,17 +28,13 @@ import java.util.List;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.SparkSession;
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.mlcontext.MLContext;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -55,19 +51,16 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysml.runtime.util.MapReduceTool;
-import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
 
 @RunWith(value = Parameterized.class)
-public class GNMFTest extends AutomatedTestBase 
+public class GNMFTest extends MLContextTestBase
 {
 	private final static String TEST_DIR = "applications/gnmf/";
 	private final static String TEST_NAME = "GNMF";
@@ -76,22 +69,11 @@ public class GNMFTest extends AutomatedTestBase
 	int numRegisteredInputs;
 	int numRegisteredOutputs;
 
-	private static SparkSession spark;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
-
 	public GNMFTest(int in, int out) {
 		numRegisteredInputs = in;
 		numRegisteredOutputs = out;
 	}
 
-	@BeforeClass
-	public static void setUpClass() {
-		spark = createSystemMLSparkSession("GNMFTest", "local");
-		ml = new MLContext(spark);
-		sc = MLContextUtil.getJavaSparkContext(ml);
-	}
-
 	@Parameters
 	 public static Collection<Object[]> data() {
 	   Object[][] data = new Object[][] { { 0, 0 }, { 3, 2 }, { 2, 2 }, { 2, 1 }, { 2, 0 }, { 3, 0 }};
@@ -256,25 +238,7 @@ public class GNMFTest extends AutomatedTestBase
 			DMLScript.USE_LOCAL_SPARK_CONFIG = oldConfig;
 		}
 	}
-	
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		sc = null;
-		spark = null;
 
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
-	
 	public static class StringToMatrixEntry implements Function<String, MatrixEntry> {
 
 		private static final long serialVersionUID = 7456391906436606324L;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
index bab719e..a7d12a5 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextFrameTest.java
@@ -28,21 +28,17 @@ import java.util.Arrays;
 import java.util.List;
 
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.apache.sysml.api.mlcontext.FrameFormat;
 import org.apache.sysml.api.mlcontext.FrameMetadata;
 import org.apache.sysml.api.mlcontext.FrameSchema;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
 import org.apache.sysml.api.mlcontext.MatrixMetadata;
@@ -50,19 +46,14 @@ import org.apache.sysml.api.mlcontext.Script;
 import org.apache.sysml.parser.Expression.ValueType;
 import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.integration.mlcontext.MLContextTest.CommaSeparatedValueStringToDoubleArrayRow;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 
 import scala.collection.Iterator;
 
-public class MLContextFrameTest extends AutomatedTestBase {
-	protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
-	protected final static String TEST_NAME = "MLContextFrame";
+public class MLContextFrameTest extends MLContextTestBase {
 
 	public static enum SCRIPT_TYPE {
 		DML, PYDML
@@ -72,25 +63,14 @@ public class MLContextFrameTest extends AutomatedTestBase {
 		ANY, FILE, JAVA_RDD_STR_CSV, JAVA_RDD_STR_IJV, RDD_STR_CSV, RDD_STR_IJV, DATAFRAME
 	};
 
-	private static SparkSession spark;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
 	private static String CSV_DELIM = ",";
 
 	@BeforeClass
 	public static void setUpClass() {
-		spark = createSystemMLSparkSession("MLContextFrameTest", "local");
-		ml = new MLContext(spark);
-		sc = MLContextUtil.getJavaSparkContext(ml);
+		MLContextTestBase.setUpClass();
 		ml.setExplainLevel(ExplainLevel.RECOMPILE_HOPS);
 	}
 
-	@Override
-	public void setUp() {
-		addTestConfiguration(TEST_DIR, TEST_NAME);
-		getAndLoadTestConfiguration(TEST_NAME);
-	}
-
 	@Test
 	public void testFrameJavaRDD_CSV_DML() {
 		testFrame(FrameFormat.CSV, SCRIPT_TYPE.DML, IO_TYPE.JAVA_RDD_STR_CSV, IO_TYPE.ANY);
@@ -644,21 +624,4 @@ public class MLContextFrameTest extends AutomatedTestBase {
 	// }
 	// }
 
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		sc = null;
-		spark = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
index c418a6f..9b58322 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextMultipleScriptsTest.java
@@ -80,10 +80,6 @@ public class MLContextMultipleScriptsTest extends AutomatedTestBase
 		runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true);
 	}
 
-	/**
-	 * 
-	 * @param platform
-	 */
 	private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) 
 	{
 		RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
index fbc413b..af6028c 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextOutputBlocksizeTest.java
@@ -21,10 +21,7 @@ package org.apache.sysml.test.integration.mlcontext;
 
 import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
@@ -36,44 +33,15 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.DataConverter;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
-
-public class MLContextOutputBlocksizeTest extends AutomatedTestBase
+public class MLContextOutputBlocksizeTest extends MLContextTestBase
 {
-	protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
-	protected final static String TEST_NAME = "MLContext";
-
 	private final static int rows = 100;
 	private final static int cols = 63;
 	private final static double sparsity = 0.7;
 
-	private static SparkConf conf;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		if (conf == null)
-			conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("MLContextTest").setMaster("local");
-		if (sc == null)
-			sc = new JavaSparkContext(conf);
-		ml = new MLContext(sc);
-	}
-
-	@Override
-	public void setUp() {
-		addTestConfiguration(TEST_DIR, TEST_NAME);
-		getAndLoadTestConfiguration(TEST_NAME);
-	}
-
-
 	@Test
 	public void testOutputBlocksizeTextcell() {
 		runMLContextOutputBlocksizeTest("text");
@@ -131,21 +99,4 @@ public class MLContextOutputBlocksizeTest extends AutomatedTestBase
 		}
 	}
 
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		sc.stop();
-		sc = null;
-		conf = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
index 68b1373..0bcecf4 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextParforDatasetTest.java
@@ -21,18 +21,15 @@ package org.apache.sysml.test.integration.mlcontext;
 
 import static org.apache.sysml.api.mlcontext.ScriptFactory.dml;
 
-import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
 import org.apache.sysml.api.mlcontext.MatrixMetadata;
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.api.mlcontext.MLContext.ExplainLevel;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
@@ -41,43 +38,16 @@ import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
 import org.apache.sysml.runtime.util.DataConverter;
-import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
-import org.junit.After;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 
-public class MLContextParforDatasetTest extends AutomatedTestBase 
+public class MLContextParforDatasetTest extends MLContextTestBase
 {
-	protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
-	protected final static String TEST_NAME = "MLContext";
 
 	private final static int rows = 100;
 	private final static int cols = 1600;
 	private final static double sparsity = 0.7;
-	
-	private static SparkConf conf;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		if (conf == null)
-			conf = SparkExecutionContext.createSystemMLSparkConf()
-				.setAppName("MLContextTest").setMaster("local");
-		if (sc == null)
-			sc = new JavaSparkContext(conf);
-		ml = new MLContext(sc);
-	}
-
-	@Override
-	public void setUp() {
-		addTestConfiguration(TEST_DIR, TEST_NAME);
-		getAndLoadTestConfiguration(TEST_NAME);
-	}
-
 
 	@Test
 	public void testParforDatasetVector() {
@@ -174,22 +144,4 @@ public class MLContextParforDatasetTest extends AutomatedTestBase
 			InfrastructureAnalyzer.setLocalMaxMemory(oldmem);	
 		}
 	}
-
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		sc.stop();
-		sc = null;
-		conf = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
index 6391919..e5e575b 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextScratchCleanupTest.java
@@ -80,10 +80,6 @@ public class MLContextScratchCleanupTest extends AutomatedTestBase
 		runMLContextTestMultipleScript(RUNTIME_PLATFORM.SPARK, true);
 	}
 
-	/**
-	 * 
-	 * @param platform
-	 */
 	private void runMLContextTestMultipleScript(RUNTIME_PLATFORM platform, boolean wRead) 
 	{
 		RUNTIME_PLATFORM oldplatform = DMLScript.rtplatform;

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
index 8bb09e2..88d1a28 100644
--- a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTest.java
@@ -45,7 +45,6 @@ import java.util.Map;
 
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.apache.spark.ml.linalg.Vector;
 import org.apache.spark.ml.linalg.VectorUDT;
@@ -54,14 +53,11 @@ import org.apache.spark.rdd.RDD;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.api.mlcontext.MLContext;
 import org.apache.sysml.api.mlcontext.MLContextConversionUtil;
 import org.apache.sysml.api.mlcontext.MLContextException;
-import org.apache.sysml.api.mlcontext.MLContextUtil;
 import org.apache.sysml.api.mlcontext.MLResults;
 import org.apache.sysml.api.mlcontext.Matrix;
 import org.apache.sysml.api.mlcontext.MatrixFormat;
@@ -73,11 +69,7 @@ import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
-import org.junit.BeforeClass;
 import org.junit.Test;
 
 import scala.Tuple2;
@@ -86,26 +78,7 @@ import scala.collection.Iterator;
 import scala.collection.JavaConversions;
 import scala.collection.Seq;
 
-public class MLContextTest extends AutomatedTestBase {
-	protected final static String TEST_DIR = "org/apache/sysml/api/mlcontext";
-	protected final static String TEST_NAME = "MLContext";
-
-	private static SparkSession spark;
-	private static JavaSparkContext sc;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		spark = createSystemMLSparkSession("MLContextTest", "local");
-		ml = new MLContext(spark);
-		sc = MLContextUtil.getJavaSparkContext(ml);
-	}
-
-	@Override
-	public void setUp() {
-		addTestConfiguration(TEST_DIR, TEST_NAME);
-		getAndLoadTestConfiguration(TEST_NAME);
-	}
+public class MLContextTest extends MLContextTestBase {
 
 	@Test
 	public void testCreateDMLScriptBasedOnStringAndExecute() {
@@ -710,9 +683,12 @@ public class MLContextTest extends AutomatedTestBase {
 		System.out.println("MLContextTest - DataFrame sum DML, mllib vector with ID column");
 
 		List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
 		JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list);
 
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow());
@@ -734,9 +710,12 @@ public class MLContextTest extends AutomatedTestBase {
 		System.out.println("MLContextTest - DataFrame sum PYDML, mllib vector with ID column");
 
 		List<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> list = new ArrayList<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>>();
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0, org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0, org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
-		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0, org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(1.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(1.0, 2.0, 3.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(2.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(4.0, 5.0, 6.0)));
+		list.add(new Tuple2<Double, org.apache.spark.mllib.linalg.Vector>(3.0,
+				org.apache.spark.mllib.linalg.Vectors.dense(7.0, 8.0, 9.0)));
 		JavaRDD<Tuple2<Double, org.apache.spark.mllib.linalg.Vector>> javaRddTuple = sc.parallelize(list);
 
 		JavaRDD<Row> javaRddRow = javaRddTuple.map(new DoubleMllibVectorRow());
@@ -2576,7 +2555,8 @@ public class MLContextTest extends AutomatedTestBase {
 	@Test
 	public void testPrintFormattingMultipleExpressions() {
 		System.out.println("MLContextTest - print formatting multiple expressions");
-		Script script = dml("a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
+		Script script = dml(
+				"a='hello'; b='goodbye'; c=4; d=3; e=3.0; f=5.0; g=FALSE; print('%s %d %f %b', (a+b), (c-d), (e*f), !g);");
 		setExpectedStdOut("hellogoodbye 1 15.000000 true");
 		ml.execute(script);
 	}
@@ -2732,7 +2712,7 @@ public class MLContextTest extends AutomatedTestBase {
 	public void testOutputListDML() {
 		System.out.println("MLContextTest - output specified as List DML");
 
-		List<String> outputs = Arrays.asList("x","y");
+		List<String> outputs = Arrays.asList("x", "y");
 		Script script = dml("a=1;x=a+1;y=x+1").out(outputs);
 		MLResults results = ml.execute(script);
 		Assert.assertEquals(2, results.getLong("x"));
@@ -2743,7 +2723,7 @@ public class MLContextTest extends AutomatedTestBase {
 	public void testOutputListPYDML() {
 		System.out.println("MLContextTest - output specified as List PYDML");
 
-		List<String> outputs = Arrays.asList("x","y");
+		List<String> outputs = Arrays.asList("x", "y");
 		Script script = pydml("a=1\nx=a+1\ny=x+1").out(outputs);
 		MLResults results = ml.execute(script);
 		Assert.assertEquals(2, results.getLong("x"));
@@ -2755,7 +2735,7 @@ public class MLContextTest extends AutomatedTestBase {
 	public void testOutputScalaSeqDML() {
 		System.out.println("MLContextTest - output specified as Scala Seq DML");
 
-		List outputs = Arrays.asList("x","y");
+		List outputs = Arrays.asList("x", "y");
 		Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
 		Script script = dml("a=1;x=a+1;y=x+1").out(seq);
 		MLResults results = ml.execute(script);
@@ -2768,7 +2748,7 @@ public class MLContextTest extends AutomatedTestBase {
 	public void testOutputScalaSeqPYDML() {
 		System.out.println("MLContextTest - output specified as Scala Seq PYDML");
 
-		List outputs = Arrays.asList("x","y");
+		List outputs = Arrays.asList("x", "y");
 		Seq seq = JavaConversions.asScalaBuffer(outputs).toSeq();
 		Script script = pydml("a=1\nx=a+1\ny=x+1").out(seq);
 		MLResults results = ml.execute(script);
@@ -2776,89 +2756,4 @@ public class MLContextTest extends AutomatedTestBase {
 		Assert.assertEquals(3, results.getLong("y"));
 	}
 
-	// NOTE: Uncomment these tests once they work
-
-	// @SuppressWarnings({ "rawtypes", "unchecked" })
-	// @Test
-	// public void testInputTupleSeqWithAndWithoutMetadataDML() {
-	// System.out.println("MLContextTest - Tuple sequence with and without
-	// metadata DML");
-	//
-	// List<String> list1 = new ArrayList<String>();
-	// list1.add("1,2");
-	// list1.add("3,4");
-	// JavaRDD<String> javaRDD1 = sc.parallelize(list1);
-	// RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
-	//
-	// List<String> list2 = new ArrayList<String>();
-	// list2.add("5,6");
-	// list2.add("7,8");
-	// JavaRDD<String> javaRDD2 = sc.parallelize(list2);
-	// RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
-	//
-	// MatrixMetadata mm1 = new MatrixMetadata(2, 2);
-	//
-	// Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
-	// Tuple2 tuple2 = new Tuple2("m2", rdd2);
-	// List tupleList = new ArrayList();
-	// tupleList.add(tuple1);
-	// tupleList.add(tuple2);
-	// Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
-	//
-	// Script script =
-	// dml("print('sums: ' + sum(m1) + ' ' + sum(m2));").in(seq);
-	// setExpectedStdOut("sums: 10.0 26.0");
-	// ml.execute(script);
-	// }
-	//
-	// @SuppressWarnings({ "rawtypes", "unchecked" })
-	// @Test
-	// public void testInputTupleSeqWithAndWithoutMetadataPYDML() {
-	// System.out.println("MLContextTest - Tuple sequence with and without
-	// metadata PYDML");
-	//
-	// List<String> list1 = new ArrayList<String>();
-	// list1.add("1,2");
-	// list1.add("3,4");
-	// JavaRDD<String> javaRDD1 = sc.parallelize(list1);
-	// RDD<String> rdd1 = JavaRDD.toRDD(javaRDD1);
-	//
-	// List<String> list2 = new ArrayList<String>();
-	// list2.add("5,6");
-	// list2.add("7,8");
-	// JavaRDD<String> javaRDD2 = sc.parallelize(list2);
-	// RDD<String> rdd2 = JavaRDD.toRDD(javaRDD2);
-	//
-	// MatrixMetadata mm1 = new MatrixMetadata(2, 2);
-	//
-	// Tuple3 tuple1 = new Tuple3("m1", rdd1, mm1);
-	// Tuple2 tuple2 = new Tuple2("m2", rdd2);
-	// List tupleList = new ArrayList();
-	// tupleList.add(tuple1);
-	// tupleList.add(tuple2);
-	// Seq seq = JavaConversions.asScalaBuffer(tupleList).toSeq();
-	//
-	// Script script =
-	// pydml("print('sums: ' + sum(m1) + ' ' + sum(m2))").in(seq);
-	// setExpectedStdOut("sums: 10.0 26.0");
-	// ml.execute(script);
-	// }
-
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		sc = null;
-		spark = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
new file mode 100644
index 0000000..380fb3f
--- /dev/null
+++ b/src/test/java/org/apache/sysml/test/integration/mlcontext/MLContextTestBase.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysml.test.integration.mlcontext;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.SparkSession;
+import org.apache.sysml.api.mlcontext.MLContext;
+import org.apache.sysml.api.mlcontext.MLContextUtil;
+import org.apache.sysml.test.integration.AutomatedTestBase;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+
+/**
+ * Abstract class that can be used for MLContext tests.
+ * <p>
+ * Note that when using the setUp() method of MLContextTestBase, the test
+ * directory and test name can be specified in the subclass if needed.
+ * <p>
+ * 
+ * Example:
+ * 
+ * <pre>
+ * public MLContextTestExample() {
+ * 	testDir = this.getClass().getPackage().getName().replace(".", File.separator);
+ * 	testName = this.getClass().getSimpleName();
+ * }
+ * </pre>
+ *
+ */
+public abstract class MLContextTestBase extends AutomatedTestBase {
+
+	protected static SparkSession spark;
+	protected static JavaSparkContext sc;
+	protected static MLContext ml;
+
+	protected String testDir = null;
+	protected String testName = null;
+
+	@Override
+	public void setUp() {
+		Class<? extends MLContextTestBase> clazz = this.getClass();
+		String dir = (testDir == null) ? "org/apache/sysml/api/mlcontext" : testDir;
+		String name = (testName == null) ? clazz.getSimpleName() : testName;
+
+		addTestConfiguration(dir, name);
+		getAndLoadTestConfiguration(name);
+	}
+
+	@BeforeClass
+	public static void setUpClass() {
+		spark = createSystemMLSparkSession("SystemML MLContext Test", "local");
+		ml = new MLContext(spark);
+		sc = MLContextUtil.getJavaSparkContext(ml);
+	}
+
+	@After
+	public void tearDown() {
+		super.tearDown();
+	}
+
+	@AfterClass
+	public static void tearDownClass() {
+		// stop underlying spark context to allow single jvm tests (otherwise
+		// the next test that tries to create a SparkContext would fail)
+		spark.stop();
+		sc = null;
+		spark = null;
+		ml.close();
+		ml = null;
+	}
+}

http://git-wip-us.apache.org/repos/asf/systemml/blob/ec38b379/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
index d86b707..92b9f67 100644
--- a/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/scripts/nn/NNTest.java
@@ -19,42 +19,20 @@
 
 package org.apache.sysml.test.integration.scripts.nn;
 
-import org.apache.spark.sql.SparkSession;
-import org.apache.sysml.api.mlcontext.MLContext;
+import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
+
 import org.apache.sysml.api.mlcontext.Script;
-import org.apache.sysml.test.integration.AutomatedTestBase;
-import org.junit.After;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
+import org.apache.sysml.test.integration.mlcontext.MLContextTestBase;
 import org.junit.Test;
 
-import static org.apache.sysml.api.mlcontext.ScriptFactory.dmlFromFile;
-
 /**
  * Test the SystemML deep learning library, `nn`.
  */
-public class NNTest extends AutomatedTestBase {
+public class NNTest extends MLContextTestBase {
 
-	private static final String TEST_NAME = "NNTest";
-	private static final String TEST_DIR = "scripts/";
 	private static final String TEST_SCRIPT = "scripts/nn/test/run_tests.dml";
 	private static final String ERROR_STRING = "ERROR:";
 
-	private static SparkSession spark;
-	private static MLContext ml;
-
-	@BeforeClass
-	public static void setUpClass() {
-		spark = createSystemMLSparkSession("MLContextTest", "local");
-		ml = new MLContext(spark);
-	}
-
-	@Override
-	public void setUp() {
-		addTestConfiguration(TEST_DIR, TEST_NAME);
-		getAndLoadTestConfiguration(TEST_NAME);
-	}
-
 	@Test
 	public void testNNLibrary() {
 		Script script = dmlFromFile(TEST_SCRIPT);
@@ -62,20 +40,4 @@ public class NNTest extends AutomatedTestBase {
 		ml.execute(script);
 	}
 
-	@After
-	public void tearDown() {
-		super.tearDown();
-	}
-
-	@AfterClass
-	public static void tearDownClass() {
-		// stop underlying spark context to allow single jvm tests (otherwise the
-		// next test that tries to create a SparkContext would fail)
-		spark.stop();
-		spark = null;
-
-		// clear status mlcontext and spark exec context
-		ml.close();
-		ml = null;
-	}
 }