Posted to commits@systemml.apache.org by de...@apache.org on 2017/05/26 06:52:12 UTC
[1/4] incubator-systemml git commit: [SYSTEMML-1303] Remove deprecated old MLContext API
Repository: incubator-systemml
Updated Branches:
refs/heads/master 0a89676fa -> 7ba17c7f6
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
index b8aa11e..ddd0f6d 100644
--- a/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysml/test/integration/AutomatedTestBase.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -37,13 +37,8 @@ import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.SparkSession.Builder;
-import org.apache.wink.json4j.JSONObject;
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.MLContext;
import org.apache.sysml.conf.DMLConfig;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.lops.Lop;
@@ -51,8 +46,6 @@ import org.apache.sysml.parser.DataExpression;
import org.apache.sysml.parser.Expression.DataType;
import org.apache.sysml.parser.Expression.ValueType;
import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.io.FrameReader;
import org.apache.sysml.runtime.io.FrameReaderFactory;
@@ -66,6 +59,10 @@ import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.test.utils.TestUtils;
import org.apache.sysml.utils.ParameterBuilder;
import org.apache.sysml.utils.Statistics;
+import org.apache.wink.json4j.JSONObject;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
/**
@@ -79,38 +76,38 @@ import org.apache.sysml.utils.Statistics;
* <li>check results</li>
* <li>clean up after test run</li>
* </ul>
- *
+ *
*/
@SuppressWarnings("deprecation")
-public abstract class AutomatedTestBase
+public abstract class AutomatedTestBase
{
-
+
public enum ScriptType {
DML, PYDML;
-
+
public String lowerCase() {
return super.toString().toLowerCase();
}
};
-
+
public static final boolean EXCEPTION_EXPECTED = true;
public static final boolean EXCEPTION_NOT_EXPECTED = false;
-
- // By default: TEST_GPU is set to false to allow developers without Nvidia GPU to run integration test suite
+
+ // By default, TEST_GPU is set to false to allow developers without an Nvidia GPU to run the integration test suite
public static final boolean TEST_GPU = false;
public static final double GPU_TOLERANCE = 1e-9;
-
+
protected ScriptType scriptType;
-
+
// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
- // Hadoop 2.4.1 doesn't work on Windows unless winutils.exe is available
+ // Hadoop 2.4.1 doesn't work on Windows unless winutils.exe is available
// under $HADOOP_HOME/bin and hadoop.dll is available in the Java library
- // path. The following static initializer sets up JVM variables so that
+ // path. The following static initializer sets up JVM variables so that
// Hadoop can find these native binaries, assuming that any Hadoop code
// loads after this class and that the JVM's current working directory
// is the root of this project.
static {
-
+
String osname = System.getProperty("os.name").toLowerCase();
if (osname.contains("win")) {
System.err.printf("AutomatedTestBase has detected a Windows OS and is overriding\n"
@@ -119,9 +116,9 @@ public abstract class AutomatedTestBase
System.setProperty("hadoop.home.dir", cwd + File.separator
+ "\\src\\test\\config\\hadoop_bin_windows");
-
+
if(TEST_GPU) {
- String CUDA_LIBRARY_PATH = System.getenv("CUDA_PATH") + File.separator + "bin";
+ String CUDA_LIBRARY_PATH = System.getenv("CUDA_PATH") + File.separator + "bin";
System.setProperty("java.library.path", cwd + File.separator
+ "\\src\\test\\config\\hadoop_bin_windows\\bin" + File.pathSeparator
+ "/lib" + File.pathSeparator
@@ -130,18 +127,18 @@ public abstract class AutomatedTestBase
else {
System.setProperty("java.library.path", cwd + File.separator
+ "\\src\\test\\config\\hadoop_bin_windows\\bin"
- // For testing BLAS on Windows
+ // For testing BLAS on Windows
// + File.pathSeparator + "C:\\Program Files (x86)\\IntelSWTools\\compilers_and_libraries_2017.0.109\\windows\\redist\\intel64_win\\mkl"
);
}
-
+
// Need to muck around with the classloader to get it to use the new
// value of java.library.path.
try {
final Field sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
sysPathsField.setAccessible(true);
-
+
sysPathsField.set(null, null);
} catch (Exception e) {
// IBM Java throws an exception here, so don't print the stack trace.
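
For readers unfamiliar with the trick above: HotSpot-style class loaders parse java.library.path once and cache the result in a private static field, so updating the system property alone has no effect. A minimal standalone sketch of the same reflection trick, assuming a JVM that actually has the sys_paths field (IBM Java does not, hence the catch block above):

    import java.lang.reflect.Field;

    public class NativePathReloader {
        // Point java.library.path at a new location and force the class
        // loader to re-parse it on the next System.loadLibrary() call.
        public static void setLibraryPath(String path) throws Exception {
            System.setProperty("java.library.path", path);
            Field sysPaths = ClassLoader.class.getDeclaredField("sys_paths");
            sysPaths.setAccessible(true);
            sysPaths.set(null, null); // null => re-initialized from the property
        }
    }
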
@@ -151,7 +148,7 @@ public abstract class AutomatedTestBase
}
}
// *** END HACK ***
-
+
/**
* Script source directory for .dml and .r files only
* (TEST_DATA_DIR for generated test data artifacts).
@@ -160,19 +157,19 @@ public abstract class AutomatedTestBase
protected static final String INPUT_DIR = "in/";
protected static final String OUTPUT_DIR = "out/";
protected static final String EXPECTED_DIR = "expected/";
-
+
/** Location where this class writes files for inspection if DEBUG is set to true. */
private static final String DEBUG_TEMP_DIR = "./tmp/";
-
+
/** Directory under which config files shared across tests are located. */
private static final String CONFIG_DIR = "./src/test/config/";
-
+
/**
* Location of the SystemML config file that we use as a template when
* generating the configs for each test case.
*/
private static final File CONFIG_TEMPLATE_FILE = new File(CONFIG_DIR, "SystemML-config.xml");
-
+
/**
* Location under which we create local temporary directories for test cases.
* To adjust where testTemp is located, use -Dsystemml.testTemp.root.dir=<new location>. This is necessary
@@ -180,7 +177,7 @@ public abstract class AutomatedTestBase
*/
private static final String LOCAL_TEMP_ROOT_DIR = System.getProperty("systemml.testTemp.root.dir","target/testTemp");
private static final File LOCAL_TEMP_ROOT = new File(LOCAL_TEMP_ROOT_DIR);
-
+
/** Base directory for generated IN, OUT, EXPECTED test data artifacts instead of SCRIPT_DIR. */
protected static final String TEST_DATA_DIR = LOCAL_TEMP_ROOT_DIR + "/";
protected static final boolean TEST_CACHE_ENABLED = true;
@@ -191,23 +188,23 @@ public abstract class AutomatedTestBase
* Runtime backend to use for all integration tests. Some individual tests
* override this value, but the rest will use whatever is the default here.
* <p>
- * Also set DMLScript.USE_LOCAL_SPARK_CONFIG to true for running the test
+ * Also set DMLScript.USE_LOCAL_SPARK_CONFIG to true for running the test
* suite in spark mode
*/
protected static RUNTIME_PLATFORM rtplatform = RUNTIME_PLATFORM.HYBRID;
-
+
protected static final boolean DEBUG = false;
protected static final boolean VISUALIZE = false;
protected static final boolean RUNNETEZZA = false;
-
+
protected String fullDMLScriptName; // used for both DML and PyDML; should probably be renamed.
// protected String fullPYDMLScriptName;
protected String fullRScriptName;
-
+
protected static String baseDirectory;
protected static String sourceDirectory;
protected HashMap<String, TestConfiguration> availableTestConfigurations;
-
+
/* For testing in the old way */
protected HashMap<String, String> testVariables; /* variables and their values */
@@ -215,16 +212,16 @@ public abstract class AutomatedTestBase
//protected String[] dmlArgs; /* program-independent arguments to SystemML (e.g., debug, execution mode) */
protected String[] programArgs; /* program-specific arguments, which are passed to SystemML via -args option */
protected String rCmd; /* Rscript foo.R arg1, arg2 ... */
-
+
protected String selectedTest;
protected String[] outputDirectories;
protected String[] comparisonFiles;
protected ArrayList<String> inputDirectories;
protected ArrayList<String> inputRFiles;
protected ArrayList<String> expectedFiles;
-
+
private File curLocalTempDir = null;
-
+
private boolean isOutAndExpectedDeletionDisabled = false;
@@ -247,7 +244,7 @@ public abstract class AutomatedTestBase
* <p>
* Adds a test configuration to the list of available test configurations.
* </p>
- *
+ *
* @param testName
* test name
* @param config
@@ -262,7 +259,7 @@ public abstract class AutomatedTestBase
* Adds a test configuration to the list of available test configurations based
* on the test directory and the test name.
* </p>
- *
+ *
* @param testDirectory
* test directory
* @param testName
@@ -272,8 +269,8 @@ public abstract class AutomatedTestBase
TestConfiguration config = new TestConfiguration(testDirectory, testName);
availableTestConfigurations.put(testName, config);
}
-
-
+
+
@Before
public final void setUpBase() {
availableTestConfigurations = new HashMap<String, TestConfiguration>();
@@ -284,7 +281,7 @@ public abstract class AutomatedTestBase
outputDirectories = new String[0];
setOutAndExpectedDeletionDisabled(false);
lTimeBeforeTest = System.currentTimeMillis();
-
+
TestUtils.clearAssertionInformation();
}
@@ -293,7 +290,7 @@ public abstract class AutomatedTestBase
* Returns a test configuration from the list of available configurations.
* If no configuration is added for the specified name, the test will fail.
* </p>
- *
+ *
* @param testName
* test name
* @return test configuration
@@ -304,15 +301,15 @@ public abstract class AutomatedTestBase
return availableTestConfigurations.get(testName);
}
-
+
/**
* <p>
* Gets a test configuration from the list of available configurations
- * and loads it if it's available. It is then returned.
+ * and loads it if it's available. It is then returned.
* If no configuration exists for the specified name, the test will fail.
- *
+ *
* </p>
- *
+ *
* @param testName
* test name
* @return test configuration
@@ -322,11 +319,11 @@ public abstract class AutomatedTestBase
loadTestConfiguration(testConfiguration);
return testConfiguration;
}
-
+
/**
* Subclasses must call {@link #loadTestConfiguration(TestConfiguration)}
* before calling this method.
- *
+ *
* @return the directory where the current test case should write temp
* files. This directory also contains the current test's customized
* SystemML config file.
@@ -338,17 +335,17 @@ public abstract class AutomatedTestBase
}
return curLocalTempDir;
}
-
+
/**
* Subclasses must call {@link #loadTestConfiguration(TestConfiguration)}
* before calling this method.
- *
+ *
* @return the location of the current test case's SystemML config file
*/
protected File getCurConfigFile() {
return new File(getCurLocalTempDir(), "SystemML-config.xml");
}
-
+
/**
* <p>
* Tests that use custom SystemML configuration should override to ensure
@@ -358,32 +355,13 @@ public abstract class AutomatedTestBase
protected File getConfigTemplateFile() {
return CONFIG_TEMPLATE_FILE;
}
-
- protected MLContext getMLContextForTesting() throws DMLRuntimeException {
- synchronized(AutomatedTestBase.class) {
-
- RUNTIME_PLATFORM oldRT = DMLScript.rtplatform;
- try {
- DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
- ExecutionContext ec = ExecutionContextFactory.createContext();
- if(ec instanceof SparkExecutionContext) {
- MLContext mlCtx = new MLContext(((SparkExecutionContext) ec).getSparkContext());
- return mlCtx;
- }
- }
- finally {
- DMLScript.rtplatform = oldRT;
- }
- throw new DMLRuntimeException("Cannot create MLContext");
- }
- }
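
Tests that used the removed getMLContextForTesting() helper migrate to the new API in org.apache.sysml.api.mlcontext. A rough sketch of the equivalent setup (the session settings here are illustrative, not what any particular test uses):

    import org.apache.spark.sql.SparkSession;
    import org.apache.sysml.api.mlcontext.MLContext;
    import org.apache.sysml.api.mlcontext.Script;
    import org.apache.sysml.api.mlcontext.ScriptFactory;

    public class NewMLContextSketch {
        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder()
                .appName("MLContextTest").master("local").getOrCreate();
            MLContext ml = new MLContext(spark);
            // Scripts are first-class objects in the new API, not raw strings.
            Script s = ScriptFactory.dml("x = 1 + 2; print(x);");
            ml.execute(s);
            spark.stop();
        }
    }
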
/**
* <p>
* Generates a random matrix with the specified characteristics and returns
* it as a two dimensional array.
* </p>
- *
+ *
* @param rows
* number of rows
* @param cols
@@ -407,7 +385,7 @@ public abstract class AutomatedTestBase
* Generates a random matrix with the specified characteristics which does
* not contain any zero values and returns it as a two dimensional array.
* </p>
- *
+ *
* @param rows
* number of rows
* @param cols
@@ -429,7 +407,7 @@ public abstract class AutomatedTestBase
* Generates a random matrix with the specified characteristics and writes
* it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param rows
@@ -455,7 +433,7 @@ public abstract class AutomatedTestBase
* Generates a random matrix with the specified characteristics and writes
* it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param rows
@@ -488,7 +466,7 @@ public abstract class AutomatedTestBase
private void cleanupExistingData(String fname, boolean cleanupRData) throws IOException {
MapReduceTool.deleteFileIfExistOnHDFS(fname);
MapReduceTool.deleteFileIfExistOnHDFS(fname + ".mtd");
- if ( cleanupRData )
+ if ( cleanupRData )
MapReduceTool.deleteFileIfExistOnHDFS(fname + ".mtx");
}
@@ -496,7 +474,7 @@ public abstract class AutomatedTestBase
* <p>
* Adds a matrix to the input path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -507,14 +485,14 @@ public abstract class AutomatedTestBase
protected double[][] writeInputMatrix(String name, double[][] matrix, boolean bIncludeR) {
String completePath = baseDirectory + INPUT_DIR + name + "/in";
String completeRPath = baseDirectory + INPUT_DIR + name + ".mtx";
-
+
try {
cleanupExistingData(baseDirectory + INPUT_DIR + name, bIncludeR);
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
-
+
TestUtils.writeTestMatrix(completePath, matrix);
if (bIncludeR) {
TestUtils.writeTestMatrix(completeRPath, matrix, true);
@@ -527,22 +505,22 @@ public abstract class AutomatedTestBase
return matrix;
}
- protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR)
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR)
{
MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc);
}
-
- protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, int nnz, boolean bIncludeR)
+
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, int nnz, boolean bIncludeR)
{
MatrixCharacteristics mc = new MatrixCharacteristics(matrix.length, matrix[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE, nnz);
return writeInputMatrixWithMTD(name, matrix, bIncludeR, mc);
}
-
- protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, MatrixCharacteristics mc)
+
+ protected double[][] writeInputMatrixWithMTD(String name, double[][] matrix, boolean bIncludeR, MatrixCharacteristics mc)
{
writeInputMatrix(name, matrix, bIncludeR);
-
+
// write metadata file
try
{
@@ -554,15 +532,15 @@ public abstract class AutomatedTestBase
e.printStackTrace();
throw new RuntimeException(e);
}
-
+
return matrix;
}
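
Typical use from inside a test subclass, with illustrative dimensions and seed: generate random data, write it with a metadata file for the DML run, and mirror it as a .mtx file for the R comparison run.

    double[][] A = getRandomMatrix(1000, 10, -1, 1, 0.9, 7); // rows, cols, min, max, sparsity, seed
    writeInputMatrixWithMTD("A", A, true); // true: also write the R-side .mtx copy
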
-
+
/**
* <p>
* Adds a matrix to the input path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -576,7 +554,7 @@ public abstract class AutomatedTestBase
* <p>
* Adds a matrix to the input path and writes it to a file in binary format.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -591,14 +569,14 @@ public abstract class AutomatedTestBase
protected void writeInputBinaryMatrix(String name, double[][] matrix, int rowsInBlock, int colsInBlock,
boolean sparseFormat) {
String completePath = baseDirectory + INPUT_DIR + name + "/in";
-
+
try {
cleanupExistingData(baseDirectory + INPUT_DIR + name, false);
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
-
+
if (rowsInBlock == 1 && colsInBlock == 1) {
TestUtils.writeBinaryTestMatrixCells(completePath, matrix);
if (DEBUG)
@@ -613,8 +591,8 @@ public abstract class AutomatedTestBase
}
/**
- * Writes the given matrix to input path, and writes the associated metadata file.
- *
+ * Writes the given matrix to input path, and writes the associated metadata file.
+ *
* @param name
* @param matrix
* @param rowsInBlock
@@ -635,7 +613,7 @@ public abstract class AutomatedTestBase
* <p>
* Adds a matrix to the expectation path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -650,7 +628,7 @@ public abstract class AutomatedTestBase
* <p>
* Adds a matrix to the expectation path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -668,7 +646,7 @@ public abstract class AutomatedTestBase
* Adds a matrix to the expectation path and writes it to a file in binary
* format.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
@@ -705,7 +683,7 @@ public abstract class AutomatedTestBase
* <p>
* Creates an expectation helper matrix which can be used to compare scalars.
* </p>
- *
+ *
* @param name
* file name
* @param value
@@ -732,36 +710,36 @@ public abstract class AutomatedTestBase
System.out.println("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
return TestUtils.readRMatrixFromFS(baseDirectory + EXPECTED_DIR + cacheDir + fileName);
}
-
+
protected static HashMap<CellIndex, Double> readDMLScalarFromHDFS(String fileName) {
return TestUtils.readDMLScalarFromHDFS(baseDirectory + OUTPUT_DIR + fileName);
}
-
-
- protected static FrameBlock readDMLFrameFromHDFS(String fileName, InputInfo iinfo)
- throws DMLRuntimeException, IOException
+
+
+ protected static FrameBlock readDMLFrameFromHDFS(String fileName, InputInfo iinfo)
+ throws DMLRuntimeException, IOException
{
//read frame data from hdfs
String strFrameFileName = baseDirectory + OUTPUT_DIR + fileName;
FrameReader reader = FrameReaderFactory.createFrameReader(iinfo);
-
+
MatrixCharacteristics md = readDMLMetaDataFile(fileName);
return reader.readFrameFromHDFS(strFrameFileName, md.getRows(), md.getCols());
}
- protected static FrameBlock readDMLFrameFromHDFS(String fileName, InputInfo iinfo, MatrixCharacteristics md)
- throws DMLRuntimeException, IOException
+ protected static FrameBlock readDMLFrameFromHDFS(String fileName, InputInfo iinfo, MatrixCharacteristics md)
+ throws DMLRuntimeException, IOException
{
//read frame data from hdfs
String strFrameFileName = baseDirectory + OUTPUT_DIR + fileName;
FrameReader reader = FrameReaderFactory.createFrameReader(iinfo);
-
+
return reader.readFrameFromHDFS(strFrameFileName, md.getRows(), md.getCols());
}
- protected static FrameBlock readRFrameFromHDFS(String fileName, InputInfo iinfo, MatrixCharacteristics md)
- throws DMLRuntimeException, IOException
+ protected static FrameBlock readRFrameFromHDFS(String fileName, InputInfo iinfo, MatrixCharacteristics md)
+ throws DMLRuntimeException, IOException
{
//read frame data from hdfs
String strFrameFileName = baseDirectory + EXPECTED_DIR + fileName;
@@ -769,7 +747,7 @@ public abstract class AutomatedTestBase
CSVFileFormatProperties fprop = new CSVFileFormatProperties();
fprop.setHeader(true);
FrameReader reader = FrameReaderFactory.createFrameReader(iinfo, fprop);
-
+
return reader.readFrameFromHDFS(strFrameFileName, md.getRows(), md.getCols());
}
@@ -779,18 +757,13 @@ public abstract class AutomatedTestBase
System.out.println("R script out: " + baseDirectory + EXPECTED_DIR + cacheDir + fileName);
return TestUtils.readRScalarFromFS(baseDirectory + EXPECTED_DIR + cacheDir + fileName);
}
-
- /**
- *
- * @param fileName
- * @param mc
- */
+
public static void checkDMLMetaDataFile(String fileName, MatrixCharacteristics mc) {
MatrixCharacteristics rmc = readDMLMetaDataFile(fileName);
Assert.assertEquals(mc.getRows(), rmc.getRows());
Assert.assertEquals(mc.getCols(), rmc.getCols());
}
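
Callers typically assert the shape of a written output right after runTest(); for example (output name and dimensions illustrative):

    // block sizes are not compared by this check, so -1/-1 is fine here
    checkDMLMetaDataFile("B", new MatrixCharacteristics(1000, 10, -1, -1));
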
-
+
public static MatrixCharacteristics readDMLMetaDataFile(String fileName)
{
try {
@@ -804,7 +777,7 @@ public abstract class AutomatedTestBase
throw new RuntimeException(ex);
}
}
-
+
public static ValueType readDMLMetaDataValueType(String fileName)
{
try {
@@ -816,29 +789,29 @@ public abstract class AutomatedTestBase
throw new RuntimeException(ex);
}
}
-
+
/**
* <p>
* Loads a test configuration with its parameters. Adds the output
* directories to the output list as well as to the list of possible
* comparison files.
* </p>
- *
+ *
* @param config
* test configuration name
- *
+ *
*/
protected void loadTestConfiguration(TestConfiguration config) {
loadTestConfiguration(config, null);
}
-
+
/**
* <p>
* Loads a test configuration with its parameters. Adds the output
* directories to the output list as well as to the list of possible
* comparison files.
* </p>
- *
+ *
* @param config
* test configuration name
* @param cacheDirectory
@@ -855,11 +828,11 @@ public abstract class AutomatedTestBase
sourceDirectory = SCRIPT_DIR + getSourceDirectory(testDirectory);
}
else {
- baseDirectory = SCRIPT_DIR + testDirectory;
+ baseDirectory = SCRIPT_DIR + testDirectory;
sourceDirectory = baseDirectory;
}
}
-
+
setCacheDirectory(cacheDirectory);
selectedTest = config.getTestScript();
@@ -881,7 +854,7 @@ public abstract class AutomatedTestBase
testVariables.put("readhelper", "Helper = read(\"" + baseDirectory + INPUT_DIR + "helper/in\", "
+ "rows=1, cols=2, format=\"text\");");
testVariables.put("Routdir", baseDirectory + EXPECTED_DIR + cacheDir);
-
+
// Create a temporary directory for this test case.
// Eventually all files written by the tests should go under here, but making
// that change will take quite a bit of effort.
@@ -895,22 +868,22 @@ public abstract class AutomatedTestBase
curLocalTempDir = new File(LOCAL_TEMP_ROOT, String.format(
"%s/%s", testDirectory, selectedTest));
}
-
+
curLocalTempDir.mkdirs();
TestUtils.clearDirectory(curLocalTempDir.getPath());
// Create a SystemML config file for this test case based on default template
// from src/test/config or derive from custom configuration provided by test.
String configTemplate = FileUtils.readFileToString(getConfigTemplateFile(), "UTF-8");
-
+
String localTemp = curLocalTempDir.getPath();
- String configContents = configTemplate.replace("<scratch>scratch_space</scratch>",
+ String configContents = configTemplate.replace("<scratch>scratch_space</scratch>",
String.format("<scratch>%s/scratch_space</scratch>", localTemp));
- configContents = configContents.replace("<localtmpdir>/tmp/systemml</localtmpdir>",
+ configContents = configContents.replace("<localtmpdir>/tmp/systemml</localtmpdir>",
String.format("<localtmpdir>%s/localtmp</localtmpdir>", localTemp));
-
+
FileUtils.write(getCurConfigFile(), configContents, "UTF-8");
-
+
System.out.printf("This test case will use SystemML config file %s\n", getCurConfigFile());
} catch (IOException e) {
throw new RuntimeException(e);
@@ -919,8 +892,8 @@ public abstract class AutomatedTestBase
if (DEBUG)
TestUtils.clearDirectory(DEBUG_TEMP_DIR + baseDirectory + INPUT_DIR);
}
-
-
+
+
/**
* <p>
@@ -928,10 +901,10 @@ public abstract class AutomatedTestBase
* directories to the output list as well as to the list of possible
* comparison files.
* </p>
- *
+ *
* @param configurationName
* test configuration name
- *
+ *
*/
@Deprecated
protected void loadTestConfiguration(String configurationName) {
@@ -943,27 +916,27 @@ public abstract class AutomatedTestBase
loadTestConfiguration(config);
}
- /**
+ /**
* Runs an R script, default to the old way
*/
protected void runRScript() {
runRScript(false);
-
+
}
/**
* Runs an R script in the old or the new way
*/
protected void runRScript(boolean newWay) {
-
- String executionFile = sourceDirectory + selectedTest + ".R";
-
+
+ String executionFile = sourceDirectory + selectedTest + ".R";
+
// *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
// Some of the R scripts will fail if the "expected" directory doesn't exist.
// Make sure the directory exists.
File expectedDir = new File(baseDirectory, "expected" + "/" + cacheDir);
expectedDir.mkdirs();
// *** END HACK ***
-
+
String cmd;
if( !newWay ) {
executionFile = executionFile + "t";
@@ -982,9 +955,9 @@ public abstract class AutomatedTestBase
"Rscript --default-packages=methods,datasets,graphics,grDevices,stats,utils");
// *** END HACK ***
}
-
+
if (System.getProperty("os.name").contains("Windows")) {
- cmd = cmd.replace('/', '\\');
+ cmd = cmd.replace('/', '\\');
executionFile = executionFile.replace('/', '\\');
}
if (DEBUG) {
@@ -995,7 +968,7 @@ public abstract class AutomatedTestBase
if( !newWay ) {
ParameterBuilder.setVariablesInScript(sourceDirectory, selectedTest + ".R", testVariables);
}
-
+
if (cacheDir.length() > 0)
{
File expectedFile = null;
@@ -1005,7 +978,7 @@ public abstract class AutomatedTestBase
{
outputFiles = testConfig.getOutputFiles();
}
-
+
if (outputFiles != null && outputFiles.length > 0)
{
expectedFile = new File (expectedDir.getPath() + "/" + outputFiles[0]);
@@ -1016,12 +989,12 @@ public abstract class AutomatedTestBase
}
}
}
-
+
try {
long t0 = System.nanoTime();
System.out.println("starting R script");
- System.out.println("cmd: " + cmd);
- Process child = Runtime.getRuntime().exec(cmd);
+ System.out.println("cmd: " + cmd);
+ Process child = Runtime.getRuntime().exec(cmd);
String outputR = IOUtils.toString(child.getInputStream());
System.out.println("Standard Output from R:" + outputR);
@@ -1081,7 +1054,7 @@ public abstract class AutomatedTestBase
* Runs a test for which no exception is expected. If SystemML executes more
* MR jobs than specified in maxMRJobs this test will fail.
* </p>
- *
+ *
* @param maxMRJobs
* specifies a maximum limit for the number of MR jobs. If set to
* -1 there is no limit.
@@ -1094,7 +1067,7 @@ public abstract class AutomatedTestBase
* <p>
* Runs a test for which the exception expectation can be specified.
* </p>
- *
+ *
* @param exceptionExpected
* exception expected
*/
@@ -1116,7 +1089,7 @@ public abstract class AutomatedTestBase
protected void runTest(boolean exceptionExpected, Class<?> expectedException) {
runTest(exceptionExpected, expectedException, -1);
}
-
+
/**
* <p>
* Runs a test for which the exception expectation can be specified as well
@@ -1135,7 +1108,7 @@ public abstract class AutomatedTestBase
protected void runTest(boolean exceptionExpected, Class<?> expectedException, int maxMRJobs) {
runTest(false, exceptionExpected, expectedException, maxMRJobs);
}
-
+
/**
* <p>
* Runs a test for which the exception expectation can be specified as well
@@ -1153,24 +1126,24 @@ public abstract class AutomatedTestBase
* -1 there is no limit.
*/
protected void runTest(boolean newWay, boolean exceptionExpected, Class<?> expectedException, int maxMRJobs) {
-
+
String executionFile = sourceDirectory + selectedTest + ".dml";
-
+
if( !newWay ) {
executionFile = executionFile + "t";
ParameterBuilder.setVariablesInScript(sourceDirectory, selectedTest + ".dml", testVariables);
}
-
+
//cleanup scratch folder (prevent side effect between tests)
cleanupScratchSpace();
-
+
ArrayList<String> args = new ArrayList<String>();
// setup arguments to SystemML
-
+
if (DEBUG) {
args.add("-Dsystemml.logging=trace");
}
-
+
if (scriptType != null) { // DML/PYDML tests have newWay==true and a non-null scriptType
switch (scriptType) {
case DML:
@@ -1181,7 +1154,7 @@ public abstract class AutomatedTestBase
}
break;
case PYDML:
- if (null != fullDMLScriptName) {
+ if (null != fullDMLScriptName) {
args.add("-f");
args.add(fullDMLScriptName);
}
@@ -1219,16 +1192,16 @@ public abstract class AutomatedTestBase
//use optional config file since default under SystemML/DML
args.add("-config");
args.add(getCurConfigFile().getPath());
-
+
if(TEST_GPU)
args.add("-gpu");
-
+
// program-specific parameters
if ( newWay ) {
for (int i=0; i < programArgs.length; i++)
args.add(programArgs[i]);
}
-
+
if (DEBUG) {
if ( !newWay )
@@ -1243,12 +1216,12 @@ public abstract class AutomatedTestBase
}
}
}
-
+
try {
String [] dmlScriptArgs = args.toArray(new String[args.size()]);
System.out.println("arguments to DMLScript: " + Arrays.toString(dmlScriptArgs));
DMLScript.main(dmlScriptArgs);
-
+
/** check number of MR jobs */
if (maxMRJobs > -1 && maxMRJobs < Statistics.getNoOfCompiledMRJobs())
fail("Limit of MR jobs is exceeded: expected: " + maxMRJobs + ", occurred: "
@@ -1271,20 +1244,20 @@ public abstract class AutomatedTestBase
}
}
}
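
For reference, the argument vector assembled above typically looks like the following by the time it reaches DMLScript.main (script and config paths are made up for the example):

    String[] dmlScriptArgs = {
        "-f", "src/test/scripts/applications/hits/HITS.dml",
        "-config", "target/testTemp/applications/hits/HITS/SystemML-config.xml",
        "-args", "in/A", "out/B"   // positional $1, $2 bindings for the script
    };
    DMLScript.main(dmlScriptArgs);
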
-
+
public void cleanupScratchSpace()
{
- try
+ try
{
//parse config file
DMLConfig conf = new DMLConfig(getCurConfigFile().getPath());
// delete the scratch_space and all contents
// (prevent side effect between tests)
- String dir = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
+ String dir = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
MapReduceTool.deleteFileIfExistOnHDFS(dir);
- }
- catch (Exception ex)
+ }
+ catch (Exception ex)
{
//ex.printStackTrace();
return; //no effect on tests
@@ -1331,7 +1304,7 @@ public abstract class AutomatedTestBase
* Compares the results of the computation with the expected ones with a
* specified tolerance.
* </p>
- *
+ *
* @param epsilon
* tolerance
*/
@@ -1358,7 +1331,7 @@ public abstract class AutomatedTestBase
* Compares the results of the computation with the expected ones with a
* specified tolerance.
* </p>
- *
+ *
* @param epsilon
* tolerance
*/
@@ -1375,8 +1348,8 @@ public abstract class AutomatedTestBase
}
}
}
-
-
+
+
/**
* Compare results of the computation with the expected results where rows may be permuted.
* @param epsilon
@@ -1395,23 +1368,23 @@ public abstract class AutomatedTestBase
}
}
}
-
+
/**
* Checks that the number of map-reduce jobs that the current test case has
* compiled is equal to the expected number. Generates a JUnit error message
* if the number is out of line.
- *
+ *
* @param expectedNumCompiled
* number of map-reduce jobs that the current test case is
* expected to compile
*/
protected void checkNumCompiledMRJobs(int expectedNumCompiled) {
-
+
if( OptimizerUtils.isSparkExecutionMode() ) {
// Skip MapReduce-related checks when running in Spark mode.
return;
}
-
+
assertEquals("Unexpected number of compiled MR jobs.",
expectedNumCompiled, Statistics.getNoOfCompiledMRJobs());
}
@@ -1421,18 +1394,18 @@ public abstract class AutomatedTestBase
* executed (as opposed to compiling into the execution plan) is equal to
* the expected number. Generates a JUnit error message if the number is out
* of line.
- *
+ *
* @param expectedNumExecuted
* number of map-reduce jobs that the current test case is
* expected to run
*/
protected void checkNumExecutedMRJobs(int expectedNumExecuted) {
-
+
if( OptimizerUtils.isSparkExecutionMode() ) {
// Skip MapReduce-related checks when running in Spark mode.
return;
}
-
+
assertEquals("Unexpected number of executed MR jobs.",
expectedNumExecuted, Statistics.getNoOfExecutedMRJobs());
}
@@ -1470,7 +1443,7 @@ public abstract class AutomatedTestBase
* <p>
* Checks the results of a computation against a number of characteristics.
* </p>
- *
+ *
* @param rows
* number of rows
* @param cols
@@ -1501,7 +1474,7 @@ public abstract class AutomatedTestBase
public void tearDown() {
System.out.println("Duration: " + (System.currentTimeMillis() - lTimeBeforeTest) + "ms");
-
+
assertTrue("expected String did not occur: " + expectedStdOut, iExpectedStdOutState == 0
|| iExpectedStdOutState == 2);
assertTrue("expected String did not occur (stderr): " + expectedStdErr, iExpectedStdErrState == 0
@@ -1511,9 +1484,9 @@ public abstract class AutomatedTestBase
if (!isOutAndExpectedDeletionDisabled()) {
- TestUtils.removeHDFSDirectories(inputDirectories.toArray(new String[inputDirectories.size()]));
+ TestUtils.removeHDFSDirectories(inputDirectories.toArray(new String[inputDirectories.size()]));
TestUtils.removeFiles(inputRFiles.toArray(new String[inputRFiles.size()]));
-
+
// The following cleanup code is disabled (see [SYSML-256]) until we can figure out
// what test cases are creating temporary directories at the root of the project.
//TestUtils.removeTemporaryFiles();
@@ -1538,7 +1511,7 @@ public abstract class AutomatedTestBase
/**
* Enables detection of expected output of a line in standard output stream.
- *
+ *
* @param expectedLine
*/
public void setExpectedStdOut(String expectedLine) {
@@ -1551,7 +1524,7 @@ public abstract class AutomatedTestBase
/**
* This class is used to compare the standard output stream against an
* expected string.
- *
+ *
*/
class ExpectedOutputStream extends OutputStream {
private String line = "";
@@ -1583,7 +1556,7 @@ public abstract class AutomatedTestBase
/**
* This class is used to compare the standard error stream against an
* expected string.
- *
+ *
*/
class ExpectedErrorStream extends OutputStream {
private String line = "";
@@ -1642,7 +1615,7 @@ public abstract class AutomatedTestBase
* <p>
* Generates a matrix containing easy to debug values in its cells.
* </p>
- *
+ *
* @param rows
* @param cols
* @param bContainsZeros
@@ -1659,7 +1632,7 @@ public abstract class AutomatedTestBase
* Generates a matrix containing easy to debug values in its cells. The
* generated matrix contains zero values
* </p>
- *
+ *
* @param rows
* @param cols
* @return
@@ -1685,39 +1658,39 @@ public abstract class AutomatedTestBase
boolean isOutAndExpectedDeletionDisabled) {
this.isOutAndExpectedDeletionDisabled = isOutAndExpectedDeletionDisabled;
}
-
+
protected String input(String input) {
return baseDirectory + INPUT_DIR + input;
}
-
+
protected String inputDir() {
return baseDirectory + INPUT_DIR;
}
-
+
protected String output(String output) {
return baseDirectory + OUTPUT_DIR + output;
}
-
+
protected String outputDir() {
return baseDirectory + OUTPUT_DIR;
}
-
+
protected String expected(String expected) {
return baseDirectory + EXPECTED_DIR + cacheDir + expected;
}
-
+
protected String expectedDir() {
return baseDirectory + EXPECTED_DIR + cacheDir;
}
-
+
protected String getScript() {
return sourceDirectory + selectedTest + "." + scriptType.lowerCase();
}
-
+
protected String getRScript() {
return sourceDirectory + selectedTest + ".R";
}
-
+
protected String getRCmd(String ... args) {
StringBuilder sb = new StringBuilder();
sb.append("Rscript ");
@@ -1728,21 +1701,21 @@ public abstract class AutomatedTestBase
}
return sb.toString();
}
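
Put together, a typical subclass test wires these helpers up roughly as follows (TEST_NAME is a hypothetical constant of the subclass):

    TestConfiguration config = getTestConfiguration(TEST_NAME);
    loadTestConfiguration(config);
    fullDMLScriptName = getScript();
    programArgs = new String[]{ "-args", input("A"), output("B") };
    rCmd = getRCmd(inputDir(), expectedDir());
    runTest(true, false, null, -1); // newWay, no exception expected, no MR-job limit
    runRScript(true);
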
-
+
private boolean isTargetTestDirectory(String path) {
return (path != null && path.contains(getClass().getSimpleName()));
}
-
+
private void setCacheDirectory(String directory) {
cacheDir = (directory != null) ? directory : "";
if (cacheDir.length() > 0 && !cacheDir.endsWith("/")) {
cacheDir += "/";
}
}
-
+
private String getSourceDirectory(String testDirectory) {
String sourceDirectory = "";
-
+
if (null != testDirectory)
{
if (testDirectory.endsWith("/"))
@@ -1751,37 +1724,37 @@ public abstract class AutomatedTestBase
}
sourceDirectory = testDirectory.substring(0, testDirectory.lastIndexOf("/") + "/".length());
}
-
+
return sourceDirectory;
}
-
+
/**
* <p>
* Adds a frame to the input path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param data
* two dimensional frame data
* @param bIncludeR
* generates also the corresponding R frame data
- * @throws IOException
- * @throws DMLRuntimeException
+ * @throws IOException
+ * @throws DMLRuntimeException
*/
- protected double[][] writeInputFrame(String name, double[][] data, boolean bIncludeR, ValueType[] schema, OutputInfo oi)
- throws DMLRuntimeException, IOException
+ protected double[][] writeInputFrame(String name, double[][] data, boolean bIncludeR, ValueType[] schema, OutputInfo oi)
+ throws DMLRuntimeException, IOException
{
String completePath = baseDirectory + INPUT_DIR + name;
String completeRPath = baseDirectory + INPUT_DIR + name + ".csv";
-
+
try {
cleanupExistingData(baseDirectory + INPUT_DIR + name, bIncludeR);
} catch (IOException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
-
+
TestUtils.writeTestFrame(completePath, data, schema, oi);
if (bIncludeR) {
TestUtils.writeTestFrame(completeRPath, data, schema, OutputInfo.CSVOutputInfo, true);
@@ -1794,18 +1767,18 @@ public abstract class AutomatedTestBase
return data;
}
- protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, ValueType[] schema, OutputInfo oi)
- throws DMLRuntimeException, IOException
+ protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, ValueType[] schema, OutputInfo oi)
+ throws DMLRuntimeException, IOException
{
MatrixCharacteristics mc = new MatrixCharacteristics(data.length, data[0].length, OptimizerUtils.DEFAULT_BLOCKSIZE, data[0].length, -1);
return writeInputFrameWithMTD(name, data, bIncludeR, mc, schema, oi);
}
-
- protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, MatrixCharacteristics mc, ValueType[] schema, OutputInfo oi)
- throws DMLRuntimeException, IOException
+
+ protected double[][] writeInputFrameWithMTD(String name, double[][] data, boolean bIncludeR, MatrixCharacteristics mc, ValueType[] schema, OutputInfo oi)
+ throws DMLRuntimeException, IOException
{
writeInputFrame(name, data, bIncludeR, schema, oi);
-
+
// write metadata file
try
{
@@ -1817,41 +1790,41 @@ public abstract class AutomatedTestBase
e.printStackTrace();
throw new RuntimeException(e);
}
-
+
return data;
}
-
+
/**
* <p>
* Adds a frame to the input path and writes it to a file.
* </p>
- *
+ *
* @param name
* directory name
* @param matrix
* two dimensional frame data
* @param schema
* @param oi
- * @throws IOException
- * @throws DMLRuntimeException
+ * @throws IOException
+ * @throws DMLRuntimeException
*/
- protected double[][] writeInputFrame(String name, double[][] data, ValueType[] schema, OutputInfo oi)
- throws DMLRuntimeException, IOException
+ protected double[][] writeInputFrame(String name, double[][] data, ValueType[] schema, OutputInfo oi)
+ throws DMLRuntimeException, IOException
{
return writeInputFrame(name, data, false, schema, oi);
}
-
+
protected boolean heavyHittersContainsSubString(String... str) {
for( String opcode : Statistics.getCPHeavyHitterOpCodes())
for( String s : str )
if(opcode.contains(s))
return true;
- return false;
+ return false;
}
/**
* Create a SystemML-preferred Spark Session.
- *
+ *
* @param appName the application name
* @param master the master value (ie, "local", etc)
* @return Spark Session
[2/4] incubator-systemml git commit: [SYSTEMML-1303] Remove deprecated old MLContext API
Posted by de...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
index 1dd3600..06a2005 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -84,8 +84,8 @@ public class SparkExecutionContext extends ExecutionContext
{
private static final Log LOG = LogFactory.getLog(SparkExecutionContext.class.getName());
private static final boolean LDEBUG = false; //local debug flag
-
- //internal configurations
+
+ //internal configurations
private static boolean LAZY_SPARKCTX_CREATION = true;
private static boolean ASYNCHRONOUS_VAR_DESTROY = true;
@@ -93,16 +93,16 @@ public class SparkExecutionContext extends ExecutionContext
//executor memory and relative fractions as obtained from the spark configuration
private static SparkClusterConfig _sconf = null;
-
- //singleton spark context (as there can be only one spark context per JVM)
- private static JavaSparkContext _spctx = null;
-
- //registry of parallelized RDDs to enforce that at any time, we spent at most
+
+ //singleton spark context (as there can be only one spark context per JVM)
+ private static JavaSparkContext _spctx = null;
+
+ //registry of parallelized RDDs to enforce that at any time, we spend at most
//10% of JVM max heap size for parallelized RDDs; if this is not sufficient,
//matrices or frames are exported to HDFS and the RDDs are created from files.
//TODO unify memory management for CP, par RDDs, and potentially broadcasts
private static MemoryManagerParRDDs _parRDDs = new MemoryManagerParRDDs(0.1);
-
+
static {
// for internal debugging only
if( LDEBUG ) {
@@ -111,31 +111,31 @@ public class SparkExecutionContext extends ExecutionContext
}
}
- protected SparkExecutionContext(boolean allocateVars, Program prog)
+ protected SparkExecutionContext(boolean allocateVars, Program prog)
{
//protected constructor to force use of ExecutionContextFactory
super( allocateVars, prog );
-
+
//spark context creation via internal initializer
if( !(LAZY_SPARKCTX_CREATION && OptimizerUtils.isHybridExecutionMode()) ) {
initSparkContext();
}
}
-
+
/**
* Returns the singleton spark context in use. In case of lazy spark context
* creation, this method blocks until the spark context is created.
- *
+ *
* @return java spark context
*/
public JavaSparkContext getSparkContext()
{
- //lazy spark context creation on demand (lazy instead of asynchronous
+ //lazy spark context creation on demand (lazy instead of asynchronous
//to avoid wait for uninitialized spark context on close)
if( LAZY_SPARKCTX_CREATION ) {
initSparkContext();
}
-
+
//return the created spark context
return _spctx;
}
@@ -144,11 +144,11 @@ public class SparkExecutionContext extends ExecutionContext
initSparkContext();
return _spctx;
}
-
+
/**
* Indicates if the spark context has been created or has
* been passed in from outside.
- *
+ *
* @return true if spark context created
*/
public synchronized static boolean isSparkContextCreated() {
@@ -159,26 +159,25 @@ public class SparkExecutionContext extends ExecutionContext
_spctx = null;
}
- public void close()
+ public void close()
{
synchronized( SparkExecutionContext.class ) {
- if( _spctx != null )
+ if( _spctx != null )
{
//stop the spark context if existing
_spctx.stop();
-
+
//make sure stopped context is never used again
- _spctx = null;
+ _spctx = null;
}
-
+
}
}
-
+
public static boolean isLazySparkContextCreation(){
return LAZY_SPARKCTX_CREATION;
}
- @SuppressWarnings("deprecation")
private synchronized static void initSparkContext()
{
//check for redundant spark context init
@@ -186,24 +185,19 @@ public class SparkExecutionContext extends ExecutionContext
return;
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-
+
//create a default spark context (master, appname, etc refer to system properties
//as given in the spark configuration or during spark-submit)
-
+
Object mlCtxObj = MLContextProxy.getActiveMLContext();
- if(mlCtxObj != null)
+ if(mlCtxObj != null)
{
// This is when DML is called through spark shell
// Will clean the passing of static variables later as this involves minimal change to DMLScript
- if (mlCtxObj instanceof org.apache.sysml.api.MLContext) {
- org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj;
- _spctx = new JavaSparkContext(mlCtx.getSparkContext());
- } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) {
- org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
- _spctx = MLContextUtil.getJavaSparkContext(mlCtx);
- }
+ org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj;
+ _spctx = MLContextUtil.getJavaSparkContext(mlCtx);
}
- else
+ else
{
if(DMLScript.USE_LOCAL_SPARK_CONFIG) {
// For now set 4 cores for integration testing :)
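
With the old API gone, the branch above no longer needs an instanceof dispatch: the only context that can register itself with MLContextProxy is the new one. Roughly, from the user side (a sketch; sc is an existing JavaSparkContext):

    MLContext ml = new MLContext(sc);
    // Creating the MLContext registers it as the active context, so a later
    // initSparkContext() call finds it via MLContextProxy.getActiveMLContext()
    // and reuses its JavaSparkContext instead of constructing a fresh one.
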
@@ -220,128 +214,128 @@ public class SparkExecutionContext extends ExecutionContext
SparkConf conf = createSystemMLSparkConf();
_spctx = new JavaSparkContext(conf);
}
-
+
_parRDDs.clear();
}
-
- // Set warning if spark.driver.maxResultSize is not set. It needs to be set before starting Spark Context for CP collect
+
+ // Set warning if spark.driver.maxResultSize is not set. It needs to be set before starting Spark Context for CP collect
String strDriverMaxResSize = _spctx.getConf().get("spark.driver.maxResultSize", "1g");
- long driverMaxResSize = UtilFunctions.parseMemorySize(strDriverMaxResSize);
+ long driverMaxResSize = UtilFunctions.parseMemorySize(strDriverMaxResSize);
if (driverMaxResSize != 0 && driverMaxResSize<OptimizerUtils.getLocalMemBudget() && !DMLScript.USE_LOCAL_SPARK_CONFIG)
LOG.warn("Configuration parameter spark.driver.maxResultSize set to " + UtilFunctions.formatMemorySize(driverMaxResSize) + "."
- + " You can set it through Spark default configuration setting either to 0 (unlimited) or to available memory budget of size "
+ + " You can set it through Spark default configuration setting either to 0 (unlimited) or to available memory budget of size "
+ UtilFunctions.formatMemorySize((long)OptimizerUtils.getLocalMemBudget()) + ".");
-
+
//globally add binaryblock serialization framework for all hdfs read/write operations
- //TODO if spark context passed in from outside (mlcontext), we need to clean this up at the end
+ //TODO if spark context passed in from outside (mlcontext), we need to clean this up at the end
if( MRJobConfiguration.USE_BINARYBLOCK_SERIALIZATION )
MRJobConfiguration.addBinaryBlockSerializationFramework( _spctx.hadoopConfiguration() );
-
+
//statistics maintenance
if( DMLScript.STATISTICS ){
Statistics.setSparkCtxCreateTime(System.nanoTime()-t0);
}
- }
-
+ }
+
/**
* Sets up a SystemML-preferred Spark configuration based on the implicit
* default configuration (as passed via configurations from outside).
- *
+ *
* @return spark configuration
*/
public static SparkConf createSystemMLSparkConf() {
SparkConf conf = new SparkConf();
-
+
//always set unlimited result size (required for cp collect)
conf.set("spark.driver.maxResultSize", "0");
-
+
//always use the fair scheduler (for single jobs, it's equivalent to fifo
//but for concurrent jobs in parfor it ensures better data locality because
//round robin assignment mitigates the problem of 'sticky slots')
if( FAIR_SCHEDULER_MODE ) {
conf.set("spark.scheduler.mode", "FAIR");
}
-
+
//increase scheduler delay (usually more robust due to better data locality)
if( !conf.contains("spark.locality.wait") ) { //default 3s
conf.set("spark.locality.wait", "5s");
}
-
+
return conf;
}
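
Since the method is public and static, standalone drivers can reuse it to pick up the same tuning before creating their own context; roughly:

    SparkConf conf = SparkExecutionContext.createSystemMLSparkConf()
        .setAppName("SystemML").setMaster("local[4]"); // app name and master illustrative
    JavaSparkContext sc = new JavaSparkContext(conf);
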
/**
* Spark instructions should call this for all matrix inputs except broadcast
* variables.
- *
+ *
* @param varname variable name
* @return JavaPairRDD of MatrixIndexes-MatrixBlocks
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockRDDHandleForVariable( String varname )
- throws DMLRuntimeException
+ public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockRDDHandleForVariable( String varname )
+ throws DMLRuntimeException
{
return (JavaPairRDD<MatrixIndexes,MatrixBlock>) getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo);
}
-
+
/**
* Spark instructions should call this for all frame inputs except broadcast
* variables.
- *
+ *
* @param varname variable name
* @return JavaPairRDD of Longs-FrameBlocks
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockRDDHandleForVariable( String varname )
- throws DMLRuntimeException
+ public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockRDDHandleForVariable( String varname )
+ throws DMLRuntimeException
{
JavaPairRDD<Long,FrameBlock> out = (JavaPairRDD<Long,FrameBlock>) getRDDHandleForVariable( varname, InputInfo.BinaryBlockInputInfo);
return out;
}
- public JavaPairRDD<?,?> getRDDHandleForVariable( String varname, InputInfo inputInfo )
+ public JavaPairRDD<?,?> getRDDHandleForVariable( String varname, InputInfo inputInfo )
throws DMLRuntimeException
{
Data dat = getVariable(varname);
if( dat instanceof MatrixObject ) {
MatrixObject mo = getMatrixObject(varname);
- return getRDDHandleForMatrixObject(mo, inputInfo);
+ return getRDDHandleForMatrixObject(mo, inputInfo);
}
else if( dat instanceof FrameObject ) {
FrameObject fo = getFrameObject(varname);
- return getRDDHandleForFrameObject(fo, inputInfo);
+ return getRDDHandleForFrameObject(fo, inputInfo);
}
else {
throw new DMLRuntimeException("Failed to obtain RDD for data type other than matrix or frame.");
}
}
-
+
/**
- * This call returns an RDD handle for a given matrix object. This includes
- * the creation of RDDs for in-memory or binary-block HDFS data.
- *
+ * This call returns an RDD handle for a given matrix object. This includes
+ * the creation of RDDs for in-memory or binary-block HDFS data.
+ *
* @param mo matrix object
* @param inputInfo input info
* @return JavaPairRDD handle for a matrix object
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public JavaPairRDD<?,?> getRDDHandleForMatrixObject( MatrixObject mo, InputInfo inputInfo )
+ public JavaPairRDD<?,?> getRDDHandleForMatrixObject( MatrixObject mo, InputInfo inputInfo )
throws DMLRuntimeException
- {
+ {
//NOTE: MB this logic should be integrated into MatrixObject
- //However, for now we cannot assume that spark libraries are
- //always available and hence only store generic references in
+ //However, for now we cannot assume that spark libraries are
+ //always available and hence only store generic references in
//matrix object while all the logic is in the SparkExecContext
-
+
JavaSparkContext sc = getSparkContext();
JavaPairRDD<?,?> rdd = null;
//CASE 1: rdd already existing (reuse if checkpoint or trigger
- //pending rdd operations if not yet cached but prevent to re-evaluate
+ //pending rdd operations if not yet cached but prevent to re-evaluate
//rdd operations if already executed and cached
- if( mo.getRDDHandle()!=null
+ if( mo.getRDDHandle()!=null
&& (mo.getRDDHandle().isCheckpointRDD() || !mo.isCached(false)) )
{
//return existing rdd handling (w/o input format change)
@@ -359,7 +353,7 @@ public class SparkExecutionContext extends ExecutionContext
if( mo.isDirty() || !mo.isHDFSFileExists() ) //write if necessary
mo.exportData();
rdd = sc.hadoopFile( mo.getFileName(), inputInfo.inputFormatClass, inputInfo.inputKeyClass, inputInfo.inputValueClass);
- rdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)rdd); //cp is workaround for read bug
+ rdd = SparkUtils.copyBinaryBlockMatrix((JavaPairRDD<MatrixIndexes, MatrixBlock>)rdd); //cp is workaround for read bug
fromFile = true;
}
else { //default case
@@ -368,7 +362,7 @@ public class SparkExecutionContext extends ExecutionContext
mo.release(); //unpin matrix
_parRDDs.registerRDD(rdd.id(), OptimizerUtils.estimatePartitionedSizeExactSparsity(mc), true);
}
-
+
//keep rdd handle for future operations on it
RDDObject rddhandle = new RDDObject(rdd, mo.getVarName());
rddhandle.setHDFSFile(fromFile);
@@ -396,43 +390,43 @@ public class SparkExecutionContext extends ExecutionContext
else {
throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");
}
-
+
//keep rdd handle for future operations on it
RDDObject rddhandle = new RDDObject(rdd, mo.getVarName());
rddhandle.setHDFSFile(true);
mo.setRDDHandle(rddhandle);
}
-
+
return rdd;
}
-
+
/**
* FIXME: currently this implementation assumes matrix representations but frame signature
* in order to support the old transform implementation.
- *
+ *
* @param fo frame object
* @param inputInfo input info
* @return JavaPairRDD handle for a frame object
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public JavaPairRDD<?,?> getRDDHandleForFrameObject( FrameObject fo, InputInfo inputInfo )
+ public JavaPairRDD<?,?> getRDDHandleForFrameObject( FrameObject fo, InputInfo inputInfo )
throws DMLRuntimeException
- {
+ {
//NOTE: MB this logic should be integrated into FrameObject
- //However, for now we cannot assume that spark libraries are
- //always available and hence only store generic references in
+ //However, for now we cannot assume that spark libraries are
+ //always available and hence only store generic references in
//matrix object while all the logic is in the SparkExecContext
-
- InputInfo inputInfo2 = (inputInfo==InputInfo.BinaryBlockInputInfo) ?
+
+ InputInfo inputInfo2 = (inputInfo==InputInfo.BinaryBlockInputInfo) ?
InputInfo.BinaryBlockFrameInputInfo : inputInfo;
-
+
JavaSparkContext sc = getSparkContext();
JavaPairRDD<?,?> rdd = null;
//CASE 1: rdd already existing (reuse if checkpoint or trigger
- //pending rdd operations if not yet cached but prevent to re-evaluate
+ //pending rdd operations if not yet cached but prevent to re-evaluate
//rdd operations if already executed and cached
- if( fo.getRDDHandle()!=null
+ if( fo.getRDDHandle()!=null
&& (fo.getRDDHandle().isCheckpointRDD() || !fo.isCached(false)) )
{
//return existing rdd handling (w/o input format change)
@@ -451,7 +445,7 @@ public class SparkExecutionContext extends ExecutionContext
fo.exportData();
}
rdd = sc.hadoopFile( fo.getFileName(), inputInfo2.inputFormatClass, inputInfo2.inputKeyClass, inputInfo2.inputValueClass);
- rdd = ((JavaPairRDD<LongWritable, FrameBlock>)rdd).mapToPair( new CopyFrameBlockPairFunction() ); //cp is workaround for read bug
+ rdd = ((JavaPairRDD<LongWritable, FrameBlock>)rdd).mapToPair( new CopyFrameBlockPairFunction() ); //cp is workaround for read bug
fromFile = true;
}
else { //default case
@@ -460,7 +454,7 @@ public class SparkExecutionContext extends ExecutionContext
fo.release(); //unpin frame
_parRDDs.registerRDD(rdd.id(), OptimizerUtils.estimatePartitionedSizeExactSparsity(mc), true);
}
-
+
//keep rdd handle for future operations on it
RDDObject rddhandle = new RDDObject(rdd, fo.getVarName());
rddhandle.setHDFSFile(fromFile);
@@ -487,64 +481,64 @@ public class SparkExecutionContext extends ExecutionContext
else {
throw new DMLRuntimeException("Incorrect input format in getRDDHandleForVariable");
}
-
+
//keep rdd handle for future operations on it
RDDObject rddhandle = new RDDObject(rdd, fo.getVarName());
rddhandle.setHDFSFile(true);
fo.setRDDHandle(rddhandle);
}
-
+
return rdd;
}
-
+
/**
* TODO So far we only create broadcast variables but never destroy
* them. This is a memory leak which might lead to executor out-of-memory.
- * However, in order to handle this, we need to keep track when broadcast
+ * However, in order to handle this, we need to keep track when broadcast
* variables are no longer required.
- *
+ *
* @param varname variable name
* @return wrapper for broadcast variables
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable( String varname )
+ public PartitionedBroadcast<MatrixBlock> getBroadcastForVariable( String varname )
throws DMLRuntimeException
- {
+ {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
MatrixObject mo = getMatrixObject(varname);
-
+
PartitionedBroadcast<MatrixBlock> bret = null;
-
+
//reuse existing broadcast handle
- if( mo.getBroadcastHandle()!=null
- && mo.getBroadcastHandle().isValid() )
+ if( mo.getBroadcastHandle()!=null
+ && mo.getBroadcastHandle().isValid() )
{
bret = mo.getBroadcastHandle().getBroadcast();
}
-
+
//create new broadcast handle (never created, evicted)
- if( bret == null )
+ if( bret == null )
{
//account for overwritten invalid broadcast (e.g., evicted)
if( mo.getBroadcastHandle()!=null )
CacheableData.addBroadcastSize(-mo.getBroadcastHandle().getSize());
-
- //obtain meta data for matrix
+
+ //obtain meta data for matrix
int brlen = (int) mo.getNumRowsPerBlock();
int bclen = (int) mo.getNumColumnsPerBlock();
-
+
//create partitioned matrix block and release memory consumed by input
MatrixBlock mb = mo.acquireRead();
PartitionedBlock<MatrixBlock> pmb = new PartitionedBlock<MatrixBlock>(mb, brlen, bclen);
mo.release();
-
+
//determine coarse-grained partitioning
int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(mo.getNumRows(), mo.getNumColumns(), brlen, bclen);
- int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
+ int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
Broadcast<PartitionedBlock<MatrixBlock>>[] ret = new Broadcast[numParts];
-
+
//create coarse-grained partitioned broadcasts
if( numParts > 1 ) {
for( int i=0; i<numParts; i++ ) {
@@ -557,60 +551,60 @@ public class SparkExecutionContext extends ExecutionContext
else { //single partition
ret[0] = getSparkContext().broadcast(pmb);
}
-
+
bret = new PartitionedBroadcast<MatrixBlock>(ret);
- BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<MatrixBlock>(bret, varname,
+ BroadcastObject<MatrixBlock> bchandle = new BroadcastObject<MatrixBlock>(bret, varname,
OptimizerUtils.estimatePartitionedSizeExactSparsity(mo.getMatrixCharacteristics()));
mo.setBroadcastHandle(bchandle);
CacheableData.addBroadcastSize(bchandle.getSize());
}
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
Statistics.incSparkBroadcastCount(1);
}
-
+
return bret;
}
-
+
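
For illustration, a minimal usage sketch of the method above (sec denotes the
current SparkExecutionContext, the variable name "W" is hypothetical, and the
getBlock accessor is assumed per the PartitionedBroadcast API):

    //obtain a partitioned broadcast of a small matrix and read its first block;
    //executors would perform the same lookup inside a Spark function
    PartitionedBroadcast<MatrixBlock> pb = sec.getBroadcastForVariable("W");
    MatrixBlock w11 = pb.getBlock(1, 1); //1-based block indexes
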
@SuppressWarnings("unchecked")
- public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable( String varname)
+ public PartitionedBroadcast<FrameBlock> getBroadcastForFrameVariable( String varname)
throws DMLRuntimeException
- {
+ {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
FrameObject fo = getFrameObject(varname);
-
+
PartitionedBroadcast<FrameBlock> bret = null;
-
+
//reuse existing broadcast handle
- if( fo.getBroadcastHandle()!=null
- && fo.getBroadcastHandle().isValid() )
+ if( fo.getBroadcastHandle()!=null
+ && fo.getBroadcastHandle().isValid() )
{
bret = fo.getBroadcastHandle().getBroadcast();
}
-
+
//create new broadcast handle (never created, evicted)
- if( bret == null )
+ if( bret == null )
{
//account for overwritten invalid broadcast (e.g., evicted)
if( fo.getBroadcastHandle()!=null )
CacheableData.addBroadcastSize(-fo.getBroadcastHandle().getSize());
-
- //obtain meta data for frame
+
+ //obtain meta data for frame
int bclen = (int) fo.getNumColumns();
int brlen = OptimizerUtils.getDefaultFrameSize();
-
+
//create partitioned frame block and release memory consumed by input
FrameBlock mb = fo.acquireRead();
PartitionedBlock<FrameBlock> pmb = new PartitionedBlock<FrameBlock>(mb, brlen, bclen);
fo.release();
-
+
//determine coarse-grained partitioning
int numPerPart = PartitionedBroadcast.computeBlocksPerPartition(fo.getNumRows(), fo.getNumColumns(), brlen, bclen);
- int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
+ int numParts = (int) Math.ceil((double)pmb.getNumRowBlocks()*pmb.getNumColumnBlocks() / numPerPart);
Broadcast<PartitionedBlock<FrameBlock>>[] ret = new Broadcast[numParts];
-
+
//create coarse-grained partitioned broadcasts
if( numParts > 1 ) {
for( int i=0; i<numParts; i++ ) {
@@ -623,41 +617,41 @@ public class SparkExecutionContext extends ExecutionContext
else { //single partition
ret[0] = getSparkContext().broadcast(pmb);
}
-
+
bret = new PartitionedBroadcast<FrameBlock>(ret);
- BroadcastObject<FrameBlock> bchandle = new BroadcastObject<FrameBlock>(bret, varname,
+ BroadcastObject<FrameBlock> bchandle = new BroadcastObject<FrameBlock>(bret, varname,
OptimizerUtils.estimatePartitionedSizeExactSparsity(fo.getMatrixCharacteristics()));
fo.setBroadcastHandle(bchandle);
CacheableData.addBroadcastSize(bchandle.getSize());
}
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkBroadCastTime(System.nanoTime() - t0);
Statistics.incSparkBroadcastCount(1);
}
-
+
return bret;
}
/**
- * Keep the output rdd of spark rdd operations as meta data of matrix/frame
+ * Keep the output rdd of spark rdd operations as metadata of matrix/frame
* objects in the symbol table.
- *
+ *
* @param varname variable name
* @param rdd JavaPairRDD handle for variable
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public void setRDDHandleForVariable(String varname, JavaPairRDD<?,?> rdd)
+ public void setRDDHandleForVariable(String varname, JavaPairRDD<?,?> rdd)
throws DMLRuntimeException
{
CacheableData<?> obj = getCacheableData(varname);
RDDObject rddhandle = new RDDObject(rdd, varname);
obj.setRDDHandle( rddhandle );
}
-
+
/**
* Utility method for creating an RDD out of an in-memory matrix block.
- *
+ *
* @param sc java spark context
* @param src matrix block
* @param brlen block row length
@@ -665,13 +659,13 @@ public class SparkExecutionContext extends ExecutionContext
* @return JavaPairRDD handle to matrix block
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen)
+ public static JavaPairRDD<MatrixIndexes,MatrixBlock> toMatrixJavaPairRDD(JavaSparkContext sc, MatrixBlock src, int brlen, int bclen)
throws DMLRuntimeException
- {
+ {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
LinkedList<Tuple2<MatrixIndexes,MatrixBlock>> list = new LinkedList<Tuple2<MatrixIndexes,MatrixBlock>>();
-
- if( src.getNumRows() <= brlen
+
+ if( src.getNumRows() <= brlen
&& src.getNumColumns() <= bclen )
{
list.addLast(new Tuple2<MatrixIndexes,MatrixBlock>(new MatrixIndexes(1,1), src));
@@ -679,44 +673,44 @@ public class SparkExecutionContext extends ExecutionContext
else
{
boolean sparse = src.isInSparseFormat();
-
+
//create and write subblocks of matrix
for(int blockRow = 0; blockRow < (int)Math.ceil(src.getNumRows()/(double)brlen); blockRow++)
for(int blockCol = 0; blockCol < (int)Math.ceil(src.getNumColumns()/(double)bclen); blockCol++)
{
int maxRow = (blockRow*brlen + brlen < src.getNumRows()) ? brlen : src.getNumRows() - blockRow*brlen;
int maxCol = (blockCol*bclen + bclen < src.getNumColumns()) ? bclen : src.getNumColumns() - blockCol*bclen;
-
+
MatrixBlock block = new MatrixBlock(maxRow, maxCol, sparse);
-
+
int row_offset = blockRow*brlen;
int col_offset = blockCol*bclen;
-
+
//copy submatrix to block
- src.sliceOperations( row_offset, row_offset+maxRow-1,
- col_offset, col_offset+maxCol-1, block );
-
+ src.sliceOperations( row_offset, row_offset+maxRow-1,
+ col_offset, col_offset+maxCol-1, block );
+
//append block to sequence file
MatrixIndexes indexes = new MatrixIndexes(blockRow+1, blockCol+1);
list.addLast(new Tuple2<MatrixIndexes,MatrixBlock>(indexes, block));
}
}
-
+
JavaPairRDD<MatrixIndexes,MatrixBlock> result = sc.parallelizePairs(list);
if (DMLScript.STATISTICS) {
Statistics.accSparkParallelizeTime(System.nanoTime() - t0);
Statistics.incSparkParallelizeCount(1);
}
-
+
return result;
}
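
For illustration, a minimal sketch of parallelizing a small in-memory block
with the method above (a 3x3 matrix fits a single 1000x1000 block, so the
single-block fast path applies; all APIs as shown in this diff):

    MatrixBlock mb = new MatrixBlock(3, 3, false); //3x3 dense block
    mb.setValue(0, 0, 1.0);
    mb.setValue(2, 2, 2.0);
    mb.recomputeNonZeros();
    JavaPairRDD<MatrixIndexes, MatrixBlock> rdd =
        SparkExecutionContext.toMatrixJavaPairRDD(sc, mb, 1000, 1000);
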
- public static JavaPairRDD<Long,FrameBlock> toFrameJavaPairRDD(JavaSparkContext sc, FrameBlock src)
+ public static JavaPairRDD<Long,FrameBlock> toFrameJavaPairRDD(JavaSparkContext sc, FrameBlock src)
throws DMLRuntimeException
- {
+ {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
LinkedList<Tuple2<Long,FrameBlock>> list = new LinkedList<Tuple2<Long,FrameBlock>>();
-
+
//create and write subblocks of matrix
int blksize = ConfigurationManager.getBlocksize();
for(int blockRow = 0; blockRow < (int)Math.ceil(src.getNumRows()/(double)blksize); blockRow++)
@@ -725,28 +719,28 @@ public class SparkExecutionContext extends ExecutionContext
int roffset = blockRow*blksize;
FrameBlock block = new FrameBlock(src.getSchema());
-
+
//copy sub frame to block, incl meta data on first
- src.sliceOperations( roffset, roffset+maxRow-1, 0, src.getNumColumns()-1, block );
+ src.sliceOperations( roffset, roffset+maxRow-1, 0, src.getNumColumns()-1, block );
if( roffset == 0 )
block.setColumnMetadata(src.getColumnMetadata());
-
+
//append block to sequence file
list.addLast(new Tuple2<Long,FrameBlock>((long)roffset+1, block));
}
-
+
JavaPairRDD<Long,FrameBlock> result = sc.parallelizePairs(list);
if (DMLScript.STATISTICS) {
Statistics.accSparkParallelizeTime(System.nanoTime() - t0);
Statistics.incSparkParallelizeCount(1);
}
-
+
return result;
}
-
+
/**
* This method is a generic abstraction for calls from the buffer pool.
- *
+ *
* @param rdd rdd object
* @param rlen number of rows
* @param clen number of columns
@@ -757,21 +751,21 @@ public class SparkExecutionContext extends ExecutionContext
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
@SuppressWarnings("unchecked")
- public static MatrixBlock toMatrixBlock(RDDObject rdd, int rlen, int clen, int brlen, int bclen, long nnz)
+ public static MatrixBlock toMatrixBlock(RDDObject rdd, int rlen, int clen, int brlen, int bclen, long nnz)
throws DMLRuntimeException
- {
+ {
return toMatrixBlock(
- (JavaPairRDD<MatrixIndexes, MatrixBlock>) rdd.getRDD(),
+ (JavaPairRDD<MatrixIndexes, MatrixBlock>) rdd.getRDD(),
rlen, clen, brlen, bclen, nnz);
}
-
+
/**
- * Utility method for creating a single matrix block out of a binary block RDD.
- * Note that this collect call might trigger execution of any pending transformations.
- *
+ * Utility method for creating a single matrix block out of a binary block RDD.
+ * Note that this collect call might trigger execution of any pending transformations.
+ *
* NOTE: This is an unguarded utility function, which requires memory for both the output matrix
* and its collected, blocked representation.
- *
+ *
* @param rdd JavaPairRDD for matrix block
* @param rlen number of rows
* @param clen number of columns
@@ -781,19 +775,19 @@ public class SparkExecutionContext extends ExecutionContext
* @return matrix block
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, int rlen, int clen, int brlen, int bclen, long nnz)
+ public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, int rlen, int clen, int brlen, int bclen, long nnz)
throws DMLRuntimeException
{
-
+
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
MatrixBlock out = null;
-
+
if( rlen <= brlen && clen <= bclen ) //SINGLE BLOCK
{
//special case without copy and nnz maintenance
List<Tuple2<MatrixIndexes,MatrixBlock>> list = rdd.collect();
-
+
if( list.size()>1 )
throw new DMLRuntimeException("Expecting no more than one result block.");
else if( list.size()==1 )
@@ -806,70 +800,70 @@ public class SparkExecutionContext extends ExecutionContext
//determine target sparse/dense representation
long lnnz = (nnz >= 0) ? nnz : (long)rlen * clen;
boolean sparse = MatrixBlock.evalSparseFormatInMemory(rlen, clen, lnnz);
-
+
//create output matrix block (w/ lazy allocation)
out = new MatrixBlock(rlen, clen, sparse, lnnz);
-
+
List<Tuple2<MatrixIndexes,MatrixBlock>> list = rdd.collect();
-
+
//copy blocks one-at-a-time into output matrix block
- long aNnz = 0;
+ long aNnz = 0;
for( Tuple2<MatrixIndexes,MatrixBlock> keyval : list )
{
//unpack index-block pair
MatrixIndexes ix = keyval._1();
MatrixBlock block = keyval._2();
-
+
//compute row/column block offsets
int row_offset = (int)(ix.getRowIndex()-1)*brlen;
int col_offset = (int)(ix.getColumnIndex()-1)*bclen;
int rows = block.getNumRows();
int cols = block.getNumColumns();
-
+
//append block
if( sparse ) { //SPARSE OUTPUT
//append block to sparse target in order to avoid shifting, where
//we use a shallow row copy in case of MCSR and single column blocks
- //note: this append requires, for multiple column blocks, a final sort
+ //note: this append requires, for multiple column blocks, a final sort
out.appendToSparse(block, row_offset, col_offset, clen>bclen);
}
else { //DENSE OUTPUT
- out.copy( row_offset, row_offset+rows-1,
- col_offset, col_offset+cols-1, block, false );
+ out.copy( row_offset, row_offset+rows-1,
+ col_offset, col_offset+cols-1, block, false );
}
-
+
//incremental maintenance nnz
aNnz += block.getNonZeros();
}
-
+
//post-processing output matrix
if( sparse && clen>bclen )
out.sortSparseRows();
out.setNonZeros(aNnz);
out.examSparsity();
}
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkCollectTime(System.nanoTime() - t0);
Statistics.incSparkCollectCount(1);
}
-
+
return out;
}
-
+
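
Continuing the sketch from toMatrixJavaPairRDD above, the collect path can be
exercised as follows; nnz=-1 signals an unknown count, so the worst case
rlen*clen drives the sparse/dense decision:

    MatrixBlock out = SparkExecutionContext.toMatrixBlock(rdd, 3, 3, 1000, 1000, -1);
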
@SuppressWarnings("unchecked")
- public static MatrixBlock toMatrixBlock(RDDObject rdd, int rlen, int clen, long nnz)
+ public static MatrixBlock toMatrixBlock(RDDObject rdd, int rlen, int clen, long nnz)
throws DMLRuntimeException
- {
+ {
return toMatrixBlock(
- (JavaPairRDD<MatrixIndexes, MatrixCell>) rdd.getRDD(),
+ (JavaPairRDD<MatrixIndexes, MatrixCell>) rdd.getRDD(),
rlen, clen, nnz);
}
-
+
/**
- * Utility method for creating a single matrix block out of a binary cell RDD.
- * Note that this collect call might trigger execution of any pending transformations.
- *
+ * Utility method for creating a single matrix block out of a binary cell RDD.
+ * Note that this collect call might trigger execution of any pending transformations.
+ *
* @param rdd JavaPairRDD for matrix block
* @param rlen number of rows
* @param clen number of columns
@@ -877,57 +871,57 @@ public class SparkExecutionContext extends ExecutionContext
* @return matrix block
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes, MatrixCell> rdd, int rlen, int clen, long nnz)
+ public static MatrixBlock toMatrixBlock(JavaPairRDD<MatrixIndexes, MatrixCell> rdd, int rlen, int clen, long nnz)
throws DMLRuntimeException
- {
+ {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
MatrixBlock out = null;
-
+
//determine target sparse/dense representation
long lnnz = (nnz >= 0) ? nnz : (long)rlen * clen;
boolean sparse = MatrixBlock.evalSparseFormatInMemory(rlen, clen, lnnz);
-
+
//create output matrix block (w/ lazy allocation)
out = new MatrixBlock(rlen, clen, sparse);
-
+
List<Tuple2<MatrixIndexes,MatrixCell>> list = rdd.collect();
-
+
//copy blocks one-at-a-time into output matrix block
for( Tuple2<MatrixIndexes,MatrixCell> keyval : list )
{
//unpack index-block pair
MatrixIndexes ix = keyval._1();
MatrixCell cell = keyval._2();
-
+
//append cell to dense/sparse target in order to avoid shifting for sparse
//note: this append requires a final sort of sparse rows
out.appendValue((int)ix.getRowIndex()-1, (int)ix.getColumnIndex()-1, cell.getValue());
}
-
+
//post-processing output matrix
if( sparse )
out.sortSparseRows();
out.recomputeNonZeros();
out.examSparsity();
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkCollectTime(System.nanoTime() - t0);
Statistics.incSparkCollectCount(1);
}
-
+
return out;
}
- public static PartitionedBlock<MatrixBlock> toPartitionedMatrixBlock(JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, int rlen, int clen, int brlen, int bclen, long nnz)
+ public static PartitionedBlock<MatrixBlock> toPartitionedMatrixBlock(JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, int rlen, int clen, int brlen, int bclen, long nnz)
throws DMLRuntimeException
{
-
+
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
PartitionedBlock<MatrixBlock> out = new PartitionedBlock<MatrixBlock>(rlen, clen, brlen, bclen);
List<Tuple2<MatrixIndexes,MatrixBlock>> list = rdd.collect();
-
+
//copy blocks one-at-a-time into output matrix block
for( Tuple2<MatrixIndexes,MatrixBlock> keyval : list )
{
@@ -936,24 +930,24 @@ public class SparkExecutionContext extends ExecutionContext
MatrixBlock block = keyval._2();
out.setBlock((int)ix.getRowIndex(), (int)ix.getColumnIndex(), block);
}
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkCollectTime(System.nanoTime() - t0);
Statistics.incSparkCollectCount(1);
}
-
+
return out;
}
@SuppressWarnings("unchecked")
- public static FrameBlock toFrameBlock(RDDObject rdd, ValueType[] schema, int rlen, int clen)
- throws DMLRuntimeException
+ public static FrameBlock toFrameBlock(RDDObject rdd, ValueType[] schema, int rlen, int clen)
+ throws DMLRuntimeException
{
JavaPairRDD<Long,FrameBlock> lrdd = (JavaPairRDD<Long,FrameBlock>) rdd.getRDD();
return toFrameBlock(lrdd, schema, rlen, clen);
}
- public static FrameBlock toFrameBlock(JavaPairRDD<Long,FrameBlock> rdd, ValueType[] schema, int rlen, int clen)
+ public static FrameBlock toFrameBlock(JavaPairRDD<Long,FrameBlock> rdd, ValueType[] schema, int rlen, int clen)
throws DMLRuntimeException
{
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
@@ -964,16 +958,16 @@ public class SparkExecutionContext extends ExecutionContext
//create output frame block (w/ lazy allocation)
FrameBlock out = new FrameBlock(schema);
out.ensureAllocatedColumns(rlen);
-
+
List<Tuple2<Long,FrameBlock>> list = rdd.collect();
-
+
//copy blocks one-at-a-time into output matrix block
for( Tuple2<Long,FrameBlock> keyval : list )
{
//unpack index-block pair
int ix = (int)(keyval._1() - 1);
FrameBlock block = keyval._2();
-
+
//copy into output frame
out.copy( ix, ix+block.getNumRows()-1, 0, block.getNumColumns()-1, block );
if( ix == 0 ) {
@@ -981,12 +975,12 @@ public class SparkExecutionContext extends ExecutionContext
out.setColumnMetadata(block.getColumnMetadata());
}
}
-
+
if (DMLScript.STATISTICS) {
Statistics.accSparkCollectTime(System.nanoTime() - t0);
Statistics.incSparkCollectCount(1);
}
-
+
return out;
}
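
For illustration, a minimal round-trip sketch for frames using the two
conversion methods above (frame, rlen, and clen are placeholders for an
existing FrameBlock and its known dimensions):

    JavaPairRDD<Long, FrameBlock> frdd =
        SparkExecutionContext.toFrameJavaPairRDD(sc, frame);
    FrameBlock back = SparkExecutionContext.toFrameBlock(
        frdd, frame.getSchema(), rlen, clen); //rlen/clen: frame dimensions
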
@@ -994,17 +988,17 @@ public class SparkExecutionContext extends ExecutionContext
public static long writeRDDtoHDFS( RDDObject rdd, String path, OutputInfo oinfo )
{
JavaPairRDD<MatrixIndexes,MatrixBlock> lrdd = (JavaPairRDD<MatrixIndexes, MatrixBlock>) rdd.getRDD();
-
+
//piggyback nnz maintenance on write
LongAccumulator aNnz = getSparkContextStatic().sc().longAccumulator("nnz");
lrdd = lrdd.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
-
+
//save file is an action which also triggers nnz maintenance
- lrdd.saveAsHadoopFile(path,
- oinfo.outputKeyClass,
- oinfo.outputValueClass,
+ lrdd.saveAsHadoopFile(path,
+ oinfo.outputKeyClass,
+ oinfo.outputValueClass,
oinfo.outputFormatClass);
-
+
//return nnz aggregate of all blocks
return aNnz.value();
}
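
For illustration, a minimal sketch of the nnz-piggybacked write above (the
HDFS path and rddObj handle are hypothetical):

    long nnz = SparkExecutionContext.writeRDDtoHDFS(
        rddObj, "hdfs:/tmp/X", OutputInfo.BinaryBlockOutputInfo);
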
@@ -1013,58 +1007,58 @@ public class SparkExecutionContext extends ExecutionContext
public static void writeFrameRDDtoHDFS( RDDObject rdd, String path, OutputInfo oinfo )
{
JavaPairRDD<?, FrameBlock> lrdd = (JavaPairRDD<Long, FrameBlock>) rdd.getRDD();
-
+
//convert keys to writables if necessary
if( oinfo == OutputInfo.BinaryBlockOutputInfo ) {
lrdd = ((JavaPairRDD<Long, FrameBlock>)lrdd).mapToPair(
new LongFrameToLongWritableFrameFunction());
oinfo = OutputInfo.BinaryBlockFrameOutputInfo;
}
-
+
//save file is an action which also triggers nnz maintenance
- lrdd.saveAsHadoopFile(path,
- oinfo.outputKeyClass,
- oinfo.outputValueClass,
+ lrdd.saveAsHadoopFile(path,
+ oinfo.outputKeyClass,
+ oinfo.outputValueClass,
oinfo.outputFormatClass);
}
-
+
///////////////////////////////////////////
// Cleanup of RDDs and Broadcast variables
///////
-
+
/**
* Adds a child rdd object to the lineage of a parent rdd.
- *
+ *
* @param varParent parent variable
* @param varChild child variable
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public void addLineageRDD(String varParent, String varChild)
- throws DMLRuntimeException
+ public void addLineageRDD(String varParent, String varChild)
+ throws DMLRuntimeException
{
RDDObject parent = getCacheableData(varParent).getRDDHandle();
RDDObject child = getCacheableData(varChild).getRDDHandle();
-
+
parent.addLineageChild( child );
}
-
+
/**
* Adds a child broadcast object to the lineage of a parent rdd.
- *
+ *
* @param varParent parent variable
* @param varChild child variable
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
- public void addLineageBroadcast(String varParent, String varChild)
- throws DMLRuntimeException
+ public void addLineageBroadcast(String varParent, String varChild)
+ throws DMLRuntimeException
{
RDDObject parent = getCacheableData(varParent).getRDDHandle();
BroadcastObject<?> child = getCacheableData(varChild).getBroadcastHandle();
-
+
parent.addLineageChild( child );
}
- public void addLineage(String varParent, String varChild, boolean broadcast)
+ public void addLineage(String varParent, String varChild, boolean broadcast)
throws DMLRuntimeException
{
if( broadcast )
@@ -1072,25 +1066,25 @@ public class SparkExecutionContext extends ExecutionContext
else
addLineageRDD(varParent, varChild);
}
-
+
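
For illustration, a minimal sketch of how an instruction might wire up lineage
(sec denotes the current SparkExecutionContext; "out" and "in" are hypothetical
variable names): the output rdd is registered and linked to its input so that
cleanup of "out" can recursively reclaim "in" once both are unreferenced:

    sec.setRDDHandleForVariable("out", outRdd);
    sec.addLineageRDD("out", "in"); //parent "out" keeps child "in" alive
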
@Override
- public void cleanupMatrixObject( MatrixObject mo )
+ public void cleanupMatrixObject( MatrixObject mo )
throws DMLRuntimeException
{
//NOTE: this method overwrites the default behavior of cleanupMatrixObject
//and hence is transparently used by rmvar instructions and other users. The
//core difference is the lineage-based cleanup of RDD and broadcast variables.
-
+
try
{
- if ( mo.isCleanupEnabled() )
+ if ( mo.isCleanupEnabled() )
{
//compute ref count only if matrix cleanup actually necessary
- if ( !getVariables().hasReferences(mo) )
+ if ( !getVariables().hasReferences(mo) )
{
- //clean cached data
- mo.clearData();
-
+ //clean cached data
+ mo.clearData();
+
//clean hdfs data if no pending rdd operations on it
if( mo.isHDFSFileExists() && mo.getFileName()!=null ) {
if( mo.getRDDHandle()==null ) {
@@ -1101,12 +1095,12 @@ public class SparkExecutionContext extends ExecutionContext
rdd.setHDFSFilename(mo.getFileName());
}
}
-
+
//cleanup RDD and broadcast variables (recursive)
//note: requires that mo.clearData already removed back references
- if( mo.getRDDHandle()!=null ) {
+ if( mo.getRDDHandle()!=null ) {
rCleanupLineageObject(mo.getRDDHandle());
- }
+ }
if( mo.getBroadcastHandle()!=null ) {
rCleanupLineageObject(mo.getBroadcastHandle());
}
@@ -1120,18 +1114,18 @@ public class SparkExecutionContext extends ExecutionContext
}
@SuppressWarnings({ "rawtypes", "unchecked" })
- private void rCleanupLineageObject(LineageObject lob)
+ private void rCleanupLineageObject(LineageObject lob)
throws IOException
- {
+ {
//abort recursive cleanup if still consumers
if( lob.getNumReferences() > 0 )
return;
-
- //abort if still reachable through matrix object (via back references for
+
+ //abort if still reachable through matrix object (via back references for
//robustness in function calls and to prevent repeated scans of the symbol table)
if( lob.hasBackReference() )
return;
-
+
//cleanup current lineage object (from driver/executors)
//incl deferred hdfs file removal (only if metadata set by cleanup call)
if( lob instanceof RDDObject ) {
@@ -1151,38 +1145,38 @@ public class SparkExecutionContext extends ExecutionContext
cleanupBroadcastVariable(bc);
CacheableData.addBroadcastSize(-((BroadcastObject)lob).getSize());
}
-
+
//recursively process lineage children
for( LineageObject c : lob.getLineageChilds() ){
c.decrementNumReferences();
rCleanupLineageObject(c);
}
}
-
+
/**
* This call destroys a broadcast variable at all executors and the driver.
* Hence, it is intended to be used on rmvar only. Depending on the
* ASYNCHRONOUS_VAR_DESTROY configuration, this is asynchronous or not.
- *
+ *
* @param bvar broadcast variable
*/
- public static void cleanupBroadcastVariable(Broadcast<?> bvar)
+ public static void cleanupBroadcastVariable(Broadcast<?> bvar)
{
- //In comparison to 'unpersist' (which would only delete the broadcast
+ //In comparison to 'unpersist' (which would only delete the broadcast
//from the executors), this call also deletes related data from the driver.
if( bvar.isValid() ) {
bvar.destroy( !ASYNCHRONOUS_VAR_DESTROY );
}
}
-
+
/**
* This call removes an rdd variable from executor memory and disk if required.
* Hence, it is intended to be used on rmvar only. Depending on the
* ASYNCHRONOUS_VAR_DESTROY configuration, this is asynchronous or not.
- *
+ *
* @param rvar rdd variable to remove
*/
- public static void cleanupRDDVariable(JavaPairRDD<?,?> rvar)
+ public static void cleanupRDDVariable(JavaPairRDD<?,?> rvar)
{
if( rvar.getStorageLevel()!=StorageLevel.NONE() ) {
rvar.unpersist( !ASYNCHRONOUS_VAR_DESTROY );
@@ -1190,72 +1184,72 @@ public class SparkExecutionContext extends ExecutionContext
}
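
For illustration, the two cleanup entry points above side by side (bvar and
rvar are placeholders for existing handles):

    SparkExecutionContext.cleanupBroadcastVariable(bvar); //destroy: driver and executors
    SparkExecutionContext.cleanupRDDVariable(rvar);       //unpersist: executor memory/disk only
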
@SuppressWarnings("unchecked")
- public void repartitionAndCacheMatrixObject( String var )
+ public void repartitionAndCacheMatrixObject( String var )
throws DMLRuntimeException
{
MatrixObject mo = getMatrixObject(var);
MatrixCharacteristics mcIn = mo.getMatrixCharacteristics();
-
+
//double check size to avoid unnecessary spark context creation
if( !OptimizerUtils.exceedsCachingThreshold(mo.getNumColumns(), (double)
OptimizerUtils.estimateSizeExactSparsity(mcIn)) )
- return;
-
+ return;
+
//get input rdd and default storage level
- JavaPairRDD<MatrixIndexes,MatrixBlock> in = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
+ JavaPairRDD<MatrixIndexes,MatrixBlock> in = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
getRDDHandleForMatrixObject(mo, InputInfo.BinaryBlockInputInfo);
-
+
//avoid unnecessary caching of input in order to reduce memory pressure
if( mo.getRDDHandle().allowsShortCircuitRead()
&& isRDDMarkedForCaching(in.id()) && !isRDDCached(in.id()) ) {
in = (JavaPairRDD<MatrixIndexes,MatrixBlock>)
((RDDObject)mo.getRDDHandle().getLineageChilds().get(0)).getRDD();
-
+
//investigate issue of unnecessarily large number of partitions
int numPartitions = SparkUtils.getNumPreferredPartitions(mcIn, in);
if( numPartitions < in.getNumPartitions() )
in = in.coalesce( numPartitions );
}
-
- //repartition rdd (force creation of shuffled rdd via merge), note: without deep copy albeit
+
+ //repartition rdd (force creation of shuffled rdd via merge), note: without deep copy albeit
//executed on the original data, because there will be no merge, i.e., no key duplicates
JavaPairRDD<MatrixIndexes,MatrixBlock> out = RDDAggregateUtils.mergeByKey(in, false);
-
+
//convert mcsr into memory-efficient csr if potentially sparse
- if( OptimizerUtils.checkSparseBlockCSRConversion(mcIn) ) {
+ if( OptimizerUtils.checkSparseBlockCSRConversion(mcIn) ) {
out = out.mapValues(new CreateSparseBlockFunction(SparseBlock.Type.CSR));
}
-
- //persist rdd in default storage level
+
+ //persist rdd in default storage level
out.persist( Checkpoint.DEFAULT_STORAGE_LEVEL )
.count(); //trigger caching to prevent contention
-
+
//create new rdd handle, in-place of current matrix object
RDDObject inro = mo.getRDDHandle(); //guaranteed to exist (see above)
RDDObject outro = new RDDObject(out, var); //create new rdd object
outro.setCheckpointRDD(true); //mark as checkpointed
outro.addLineageChild(inro); //keep lineage to prevent cycles on cleanup
- mo.setRDDHandle(outro);
+ mo.setRDDHandle(outro);
}
@SuppressWarnings("unchecked")
- public void cacheMatrixObject( String var )
+ public void cacheMatrixObject( String var )
throws DMLRuntimeException
{
//get input rdd and default storage level
MatrixObject mo = getMatrixObject(var);
-
+
//double check size to avoid unnecessary spark context creation
if( !OptimizerUtils.exceedsCachingThreshold(mo.getNumColumns(), (double)
OptimizerUtils.estimateSizeExactSparsity(mo.getMatrixCharacteristics())) )
- return;
-
- JavaPairRDD<MatrixIndexes,MatrixBlock> in = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
+ return;
+
+ JavaPairRDD<MatrixIndexes,MatrixBlock> in = (JavaPairRDD<MatrixIndexes, MatrixBlock>)
getRDDHandleForMatrixObject(mo, InputInfo.BinaryBlockInputInfo);
-
+
//persist rdd (force rdd caching, if not already cached)
if( !isRDDCached(in.id()) )
- in.count(); //trigger caching to prevent contention
+ in.count(); //trigger caching to prevent contention
}
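
For illustration, a minimal sketch of pre-materializing inputs before an
iterative algorithm with the two methods above ("X" and "y" are hypothetical
variable names in the symbol table):

    sec.repartitionAndCacheMatrixObject("X"); //shuffled, CSR-converted, persisted
    sec.cacheMatrixObject("y");               //persisted only if not already cached
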
public void setThreadLocalSchedulerPool(String poolName) {
@@ -1283,7 +1277,7 @@ public class SparkExecutionContext extends ExecutionContext
if( !jsc.sc().getPersistentRDDs().contains(rddID) ) {
return false;
}
-
+
//check that rdd is actually already cached
for( RDDInfo info : jsc.sc().getRDDStorageInfo() ) {
if( info.id() == rddID )
@@ -1293,35 +1287,35 @@ public class SparkExecutionContext extends ExecutionContext
}
///////////////////////////////////////////
- // Spark configuration handling
+ // Spark configuration handling
///////
/**
- * Obtains the lazily analyzed spark cluster configuration.
- *
+ * Obtains the lazily analyzed spark cluster configuration.
+ *
* @return spark cluster configuration
*/
public static SparkClusterConfig getSparkClusterConfig() {
- //lazy creation of spark cluster config
+ //lazy creation of spark cluster config
if( _sconf == null )
_sconf = new SparkClusterConfig();
return _sconf;
}
-
+
/**
* Obtains the available memory budget for broadcast variables in bytes.
- *
+ *
* @return broadcast memory budget
*/
public static double getBroadcastMemoryBudget() {
return getSparkClusterConfig()
.getBroadcastMemoryBudget();
}
-
+
/**
* Obtain the available memory budget for data storage in bytes.
- *
- * @param min flag for minimum data budget
+ *
+ * @param min flag for minimum data budget
* @param refresh flag for refresh with spark context
* @return data memory budget
*/
@@ -1329,21 +1323,21 @@ public class SparkExecutionContext extends ExecutionContext
return getSparkClusterConfig()
.getDataMemoryBudget(min, refresh);
}
-
+
/**
* Obtain the number of executors in the cluster (excluding the driver).
- *
+ *
* @return number of executors
*/
public static int getNumExecutors() {
return getSparkClusterConfig()
.getNumExecutors();
}
-
+
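
For illustration, a minimal sketch of querying these static accessors (the
returned values depend on the active SparkConf; method names as defined in
this class):

    double bcBudget = SparkExecutionContext.getBroadcastMemoryBudget();
    long minData = SparkExecutionContext.getSparkClusterConfig()
        .getDataMemoryBudget(true, false); //min fraction, no refresh
    int numExec = SparkExecutionContext.getNumExecutors();
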
/**
- * Obtain the default degree of parallelism (cores in the cluster).
- *
- * @param refresh flag for refresh with spark context
+ * Obtain the default degree of parallelism (cores in the cluster).
+ *
+ * @param refresh flag for refresh with spark context
* @return default degree of parallelism
*/
public static int getDefaultParallelism(boolean refresh) {
@@ -1360,13 +1354,13 @@ public class SparkExecutionContext extends ExecutionContext
int numExecutors = getNumExecutors();
int numCores = getDefaultParallelism(false);
boolean multiThreaded = (numCores > numExecutors);
-
+
//check for jdk version less than 8 (and raise warning if multi-threaded)
- if( isLtJDK8 && multiThreaded)
+ if( isLtJDK8 && multiThreaded)
{
- //get the jre version
+ //get the jre version
String version = System.getProperty("java.version");
-
+
LOG.warn("########################################################################################");
LOG.warn("### WARNING: Multi-threaded text reblock may lead to thread contention on JRE < 1.8 ####");
LOG.warn("### java.version = " + version);
@@ -1377,51 +1371,51 @@ public class SparkExecutionContext extends ExecutionContext
LOG.warn("########################################################################################");
}
}
-
+
/**
- * Captures relevant spark cluster configuration properties, e.g., memory budgets and
+ * Captures relevant spark cluster configuration properties, e.g., memory budgets and
* degree of parallelism. This configuration abstracts legacy (< Spark 1.6) and current
- * configurations and provides a unified view.
+ * configurations and provides a unified view.
*/
- private static class SparkClusterConfig
+ private static class SparkClusterConfig
{
//broadcasts are stored in mem-and-disk in data space, this config
//defines the fraction of data space to be used as broadcast budget
private static final double BROADCAST_DATA_FRACTION = 0.3;
-
+
//forward private config from Spark's UnifiedMemoryManager.scala (>1.6)
private static final long RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024;
-
+
//meta configurations
private boolean _legacyVersion = false; //spark version <1.6
private boolean _confOnly = false; //infrastructure info based on config
-
+
//memory management configurations
private long _memExecutor = -1; //mem per executor
private double _memDataMinFrac = -1; //minimum data fraction
private double _memDataMaxFrac = -1; //maximum data fraction
private double _memBroadcastFrac = -1; //broadcast fraction
-
+
//degree of parallelism configurations
private int _numExecutors = -1; //total executors
- private int _defaultPar = -1; //total vcores
-
- public SparkClusterConfig()
+ private int _defaultPar = -1; //total vcores
+
+ public SparkClusterConfig()
{
SparkConf sconf = createSystemMLSparkConf();
_confOnly = true;
-
+
//parse version and config
String sparkVersion = getSparkVersionString();
_legacyVersion = (UtilFunctions.compareVersion(sparkVersion, "1.6.0") < 0
|| sconf.getBoolean("spark.memory.useLegacyMode", false) );
-
+
//obtain basic spark configurations
if( _legacyVersion )
analyzeSparkConfiguationLegacy(sconf);
else
analyzeSparkConfiguation(sconf);
-
+
//log debug of created spark cluster config
if( LOG.isDebugEnabled() )
LOG.debug( this.toString() );
@@ -1432,30 +1426,30 @@ public class SparkExecutionContext extends ExecutionContext
}
public long getDataMemoryBudget(boolean min, boolean refresh) {
- //always get the current num executors on refresh because this might
+ //always get the current num executors on refresh because this might
//change if not all executors are initially allocated and it is plan-relevant
int numExec = _numExecutors;
if( refresh && !_confOnly ) {
JavaSparkContext jsc = getSparkContextStatic();
numExec = Math.max(jsc.sc().getExecutorMemoryStatus().size() - 1, 1);
}
-
+
//compute data memory budget
return (long) ( numExec * _memExecutor *
- (min ? _memDataMinFrac : _memDataMaxFrac) );
+ (min ? _memDataMinFrac : _memDataMaxFrac) );
}
public int getNumExecutors() {
if( _numExecutors < 0 )
- analyzeSparkParallelismConfiguation(null);
+ analyzeSparkParallelismConfiguation(null);
return _numExecutors;
}
public int getDefaultParallelism(boolean refresh) {
if( _defaultPar < 0 && !refresh )
analyzeSparkParallelismConfiguation(null);
-
- //always get the current default parallelism on refresh because this might
+
+ //always get the current default parallelism on refresh because this might
//change if not all executors are initially allocated and it is plan-relevant
return ( refresh && !_confOnly ) ?
getSparkContextStatic().defaultParallelism() : _defaultPar;
@@ -1464,36 +1458,36 @@ public class SparkExecutionContext extends ExecutionContext
public void analyzeSparkConfiguationLegacy(SparkConf conf) {
//ensure allocated spark conf
SparkConf sconf = (conf == null) ? createSystemMLSparkConf() : conf;
-
+
//parse absolute executor memory
_memExecutor = UtilFunctions.parseMemorySize(
sconf.get("spark.executor.memory", "1g"));
-
+
//get data and shuffle memory ratios (defaults not specified in job conf)
double dataFrac = sconf.getDouble("spark.storage.memoryFraction", 0.6); //default 60%
_memDataMinFrac = dataFrac;
_memDataMaxFrac = dataFrac;
_memBroadcastFrac = dataFrac * BROADCAST_DATA_FRACTION; //default 18%
-
- //analyze spark degree of parallelism
- analyzeSparkParallelismConfiguation(sconf);
+
+ //analyze spark degree of parallelism
+ analyzeSparkParallelismConfiguation(sconf);
}
public void analyzeSparkConfiguation(SparkConf conf) {
//ensure allocated spark conf
SparkConf sconf = (conf == null) ? createSystemMLSparkConf() : conf;
-
+
//parse absolute executor memory, incl fixed cut off
_memExecutor = UtilFunctions.parseMemorySize(
- sconf.get("spark.executor.memory", "1g"))
+ sconf.get("spark.executor.memory", "1g"))
- RESERVED_SYSTEM_MEMORY_BYTES;
-
+
//get data and shuffle memory ratios (defaults not specified in job conf)
_memDataMinFrac = sconf.getDouble("spark.memory.storageFraction", 0.5); //default 50%
_memDataMaxFrac = sconf.getDouble("spark.memory.fraction", 0.75); //default 75%
_memBroadcastFrac = _memDataMaxFrac * BROADCAST_DATA_FRACTION; //default 22.5%
-
- //analyze spark degree of parallelism
+
+ //analyze spark degree of parallelism
analyzeSparkParallelismConfiguation(sconf);
}
@@ -1501,7 +1495,7 @@ public class SparkExecutionContext extends ExecutionContext
int numExecutors = sconf.getInt("spark.executor.instances", -1);
int numCoresPerExec = sconf.getInt("spark.executor.cores", -1);
int defaultPar = sconf.getInt("spark.default.parallelism", -1);
-
+
if( numExecutors > 1 && (defaultPar > 1 || numCoresPerExec > 1) ) {
_numExecutors = numExecutors;
_defaultPar = (defaultPar>1) ? defaultPar : numExecutors * numCoresPerExec;
@@ -1512,28 +1506,28 @@ public class SparkExecutionContext extends ExecutionContext
//note: spark context provides this information while conf does not
//(for num executors we need to correct for driver and local mode)
JavaSparkContext jsc = getSparkContextStatic();
- _numExecutors = Math.max(jsc.sc().getExecutorMemoryStatus().size() - 1, 1);
+ _numExecutors = Math.max(jsc.sc().getExecutorMemoryStatus().size() - 1, 1);
_defaultPar = jsc.defaultParallelism();
- _confOnly &= false; //implies env info refresh w/ spark context
+ _confOnly &= false; //implies env info refresh w/ spark context
}
}
-
+
/**
* Obtains the spark version string. If the spark context has been created,
- * we simply get it from the context; otherwise, we use Spark internal
- * constants to avoid creating the spark context just for the version.
- *
+ * we simply get it from the context; otherwise, we use Spark internal
+ * constants to avoid creating the spark context just for the version.
+ *
* @return spark version string
*/
private String getSparkVersionString() {
//check for existing spark context
- if( isSparkContextCreated() )
+ if( isSparkContextCreated() )
return getSparkContextStatic().version();
-
+
//use spark internal constant to avoid context creation
return org.apache.spark.package$.MODULE$.SPARK_VERSION();
}
-
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder("SparkClusterConfig: \n");
@@ -1544,17 +1538,17 @@ public class SparkExecutionContext extends ExecutionContext
sb.append("-- memDataMaxFrac = " + _memDataMaxFrac + "\n");
sb.append("-- memBroadcastFrac = " + _memBroadcastFrac + "\n");
sb.append("-- numExecutors = " + _numExecutors + "\n");
- sb.append("-- defaultPar = " + _defaultPar + "\n");
+ sb.append("-- defaultPar = " + _defaultPar + "\n");
return sb.toString();
}
}
-
- private static class MemoryManagerParRDDs
+
+ private static class MemoryManagerParRDDs
{
private final long _limit;
private long _size;
private HashMap<Integer, Long> _rdds;
-
+
public MemoryManagerParRDDs(double fractionMem) {
_limit = (long)(fractionMem * InfrastructureAnalyzer.getLocalMaxMemory());
_size = 0;
@@ -1566,7 +1560,7 @@ public class SparkExecutionContext extends ExecutionContext
_size += ret ? rddSize : 0;
return ret;
}
-
+
public synchronized void registerRDD(int rddID, long rddSize, boolean reserved) {
if( !reserved ) {
throw new RuntimeException("Unsupported rdd registration "
@@ -1574,7 +1568,7 @@ public class SparkExecutionContext extends ExecutionContext
}
_rdds.put(rddID, rddSize);
}
-
+
public synchronized void deregisterRDD(int rddID) {
long rddSize = _rdds.remove(rddID);
_size -= rddSize;
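
The class above tracks the aggregate size of parallelized RDDs against a
fraction of the local max memory. A minimal sketch of the intended protocol,
mirroring the registerRDD call site in getRDDHandleForMatrixObject earlier in
this diff (mc and rdd are placeholders):

    //account for a newly parallelized rdd (reserved=true is mandatory)
    long est = OptimizerUtils.estimatePartitionedSizeExactSparsity(mc);
    _parRDDs.registerRDD(rdd.id(), est, true);
    //...and release the accounted size once the rdd is cleaned up
    _parRDDs.deregisterRDD(rdd.id());
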
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/GetMLBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/GetMLBlock.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/GetMLBlock.java
deleted file mode 100644
index a1173bd..0000000
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/GetMLBlock.java
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.runtime.instructions.spark.functions;
-
-import java.io.Serializable;
-
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Row;
-
-import scala.Tuple2;
-
-import org.apache.sysml.api.MLBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-@SuppressWarnings("deprecation")
-public class GetMLBlock implements Function<Tuple2<MatrixIndexes,MatrixBlock>, Row>, Serializable {
-
- private static final long serialVersionUID = 8829736765002126985L;
-
- @Override
- public Row call(Tuple2<MatrixIndexes, MatrixBlock> kv) throws Exception {
- return new MLBlock(kv._1, kv._2);
- }
-
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
index f206fbd..377ca2e 100644
--- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
+++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -63,14 +63,14 @@ import scala.Tuple2;
* can be moved to RDDConverterUtils.
*/
@SuppressWarnings("unused")
-public class RDDConverterUtilsExt
+public class RDDConverterUtilsExt
{
public enum RDDConverterTypes {
TEXT_TO_MATRIX_CELL,
MATRIXENTRY_TO_MATRIXCELL
}
-
-
+
+
/**
* Example usage:
* <pre><code>
@@ -88,7 +88,7 @@ public class RDDConverterUtilsExt
* val mc = new MatrixCharacteristics(numRows, numCols, 1000, 1000, nnz)
* val binBlocks = RDDConverterUtilsExt.coordinateMatrixToBinaryBlock(new JavaSparkContext(sc), coordinateMatrix, mc, true)
* </code></pre>
- *
+ *
* @param sc java spark context
* @param input coordinate matrix
* @param mcIn matrix characteristics
@@ -97,26 +97,26 @@ public class RDDConverterUtilsExt
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(JavaSparkContext sc,
- CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) throws DMLRuntimeException
+ CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) throws DMLRuntimeException
{
//convert matrix entry rdd to binary block rdd (w/ partial blocks)
JavaPairRDD<MatrixIndexes, MatrixBlock> out = input.entries().toJavaRDD()
.mapPartitionsToPair(new MatrixEntryToBinaryBlockFunction(mcIn));
-
- //inject empty blocks (if necessary)
+
+ //inject empty blocks (if necessary)
if( outputEmptyBlocks && mcIn.mightHaveEmptyBlocks() ) {
- out = out.union(
+ out = out.union(
SparkUtils.getEmptyBlockRDD(sc, mcIn) );
}
-
+
//aggregate partial matrix blocks
- out = RDDAggregateUtils.mergeByKey(out, false);
-
+ out = RDDAggregateUtils.mergeByKey(out, false);
+
return out;
}
-
+
public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBlock(SparkContext sc,
- CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) throws DMLRuntimeException
+ CoordinateMatrix input, MatrixCharacteristics mcIn, boolean outputEmptyBlocks) throws DMLRuntimeException
{
return coordinateMatrixToBinaryBlock(new JavaSparkContext(sc), input, mcIn, true);
}
@@ -128,19 +128,19 @@ public class RDDConverterUtilsExt
}
return df.select(columns.get(0), scala.collection.JavaConversions.asScalaBuffer(columnToSelect).toList());
}
-
+
public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen) throws DMLRuntimeException {
return convertPy4JArrayToMB(data, (int)rlen, (int)clen, false);
}
-
+
public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen) throws DMLRuntimeException {
return convertPy4JArrayToMB(data, rlen, clen, false);
}
-
+
public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, long rlen, long clen, long nnz) throws DMLRuntimeException {
return convertSciPyCOOToMB(data, row, col, (int)rlen, (int)clen, (int)nnz);
}
-
+
public static MatrixBlock convertSciPyCOOToMB(byte [] data, byte [] row, byte [] col, int rlen, int clen, int nnz) throws DMLRuntimeException {
MatrixBlock mb = new MatrixBlock(rlen, clen, true);
mb.allocateSparseRowsBlock(false);
@@ -154,17 +154,17 @@ public class RDDConverterUtilsExt
double val = buf1.getDouble();
int rowIndex = buf2.getInt();
int colIndex = buf3.getInt();
- mb.setValue(rowIndex, colIndex, val);
+ mb.setValue(rowIndex, colIndex, val);
}
mb.recomputeNonZeros();
mb.examSparsity();
return mb;
}
-
+
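
For illustration, a minimal sketch of the expected input encoding for the COO
converter above, assuming native byte order as used by the dense converter
below (a single nonzero (0,1)=3.5 in a 1x2 matrix):

    byte[] vals = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder()).putDouble(3.5).array();
    byte[] rows = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder()).putInt(0).array();
    byte[] cols = ByteBuffer.allocate(4).order(ByteOrder.nativeOrder()).putInt(1).array();
    MatrixBlock mb = RDDConverterUtilsExt.convertSciPyCOOToMB(vals, rows, cols, 1, 2, 1);
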
public static MatrixBlock convertPy4JArrayToMB(byte [] data, long rlen, long clen, boolean isSparse) throws DMLRuntimeException {
return convertPy4JArrayToMB(data, (int) rlen, (int) clen, isSparse);
}
-
+
public static MatrixBlock allocateDenseOrSparse(int rlen, int clen, boolean isSparse) {
MatrixBlock ret = new MatrixBlock(rlen, clen, isSparse);
ret.allocateDenseOrSparseBlock();
@@ -176,7 +176,7 @@ public class RDDConverterUtilsExt
}
return allocateDenseOrSparse(rlen, clen, isSparse);
}
-
+
public static void copyRowBlocks(MatrixBlock mb, int rowIndex, MatrixBlock ret, int numRowsPerBlock, int rlen, int clen) throws DMLRuntimeException {
copyRowBlocks(mb, (long)rowIndex, ret, (long)numRowsPerBlock, (long)rlen, (long)clen);
}
@@ -192,12 +192,12 @@ public class RDDConverterUtilsExt
ret.copy((int)(rowIndex*numRowsPerBlock), (int)Math.min((rowIndex+1)*numRowsPerBlock-1, rlen-1), 0, (int)(clen-1), mb, false);
// }
}
-
+
public static void postProcessAfterCopying(MatrixBlock ret) throws DMLRuntimeException {
ret.recomputeNonZeros();
ret.examSparsity();
}
-
+
public static MatrixBlock convertPy4JArrayToMB(byte [] data, int rlen, int clen, boolean isSparse) throws DMLRuntimeException {
MatrixBlock mb = new MatrixBlock(rlen, clen, isSparse, -1);
if(isSparse) {
@@ -219,19 +219,19 @@ public class RDDConverterUtilsExt
mb.examSparsity();
return mb;
}
-
+
public static byte [] convertMBtoPy4JDenseArr(MatrixBlock mb) throws DMLRuntimeException {
byte [] ret = null;
if(mb.isInSparseFormat()) {
mb.sparseToDense();
}
-
+
long limit = mb.getNumRows()*mb.getNumColumns();
int times = Double.SIZE / Byte.SIZE;
if( limit > Integer.MAX_VALUE / times )
throw new DMLRuntimeException("MatrixBlock of size " + limit + " cannot be converted to dense numpy array");
ret = new byte[(int) (limit * times)];
-
+
double [] denseBlock = mb.getDenseBlock();
if(mb.isEmptyBlock()) {
for(int i=0;i < limit;i++){
@@ -246,10 +246,10 @@ public class RDDConverterUtilsExt
ByteBuffer.wrap(ret, i*times, times).order(ByteOrder.nativeOrder()).putDouble(denseBlock[i]);
}
}
-
+
return ret;
}
-
+
public static class AddRowID implements Function<Tuple2<Row,Long>, Row> {
private static final long serialVersionUID = -3733816995375745659L;
@@ -263,12 +263,12 @@ public class RDDConverterUtilsExt
fields[oldNumCols] = new Double(arg0._2 + 1);
return RowFactory.create(fields);
}
-
+
}
/**
* Add element indices as new column to DataFrame
- *
+ *
* @param df input data frame
* @param sparkSession the Spark Session
* @param nameOfCol name of index column
@@ -286,27 +286,10 @@ public class RDDConverterUtilsExt
return sparkSession.createDataFrame(newRows, new StructType(newSchema));
}
- /**
- * Add element indices as new column to DataFrame
- *
- * @param df input data frame
- * @param sqlContext the SQL Context
- * @param nameOfCol name of index column
- * @return new data frame
- *
- * @deprecated This will be removed in SystemML 1.0.
- */
- @Deprecated
- public static Dataset<Row> addIDToDataFrame(Dataset<Row> df, SQLContext sqlContext, String nameOfCol) {
- SparkSession sparkSession = sqlContext.sparkSession();
- return addIDToDataFrame(df, sparkSession, nameOfCol);
- }
-
-
- private static class MatrixEntryToBinaryBlockFunction implements PairFlatMapFunction<Iterator<MatrixEntry>,MatrixIndexes,MatrixBlock>
+ private static class MatrixEntryToBinaryBlockFunction implements PairFlatMapFunction<Iterator<MatrixEntry>,MatrixIndexes,MatrixBlock>
{
private static final long serialVersionUID = 4907483236186747224L;
-
+
private IJVToBinaryBlockFunctionHelper helper = null;
public MatrixEntryToBinaryBlockFunction(MatrixCharacteristics mc) throws DMLRuntimeException {
helper = new IJVToBinaryBlockFunctionHelper(mc);
@@ -318,18 +301,18 @@ public class RDDConverterUtilsExt
}
}
-
+
private static class IJVToBinaryBlockFunctionHelper implements Serializable {
private static final long serialVersionUID = -7952801318564745821L;
//internal buffer size (aligned w/ default matrix block size)
private static final int BUFFER_SIZE = 4 * 1000 * 1000; //4M elements (32MB)
private int _bufflen = -1;
-
+
private long _rlen = -1;
private long _clen = -1;
private int _brlen = -1;
private int _bclen = -1;
-
+
public IJVToBinaryBlockFunctionHelper(MatrixCharacteristics mc) throws DMLRuntimeException
{
if(!mc.dimsKnown()) {
@@ -339,21 +322,21 @@ public class RDDConverterUtilsExt
_clen = mc.getCols();
_brlen = mc.getRowsPerBlock();
_bclen = mc.getColsPerBlock();
-
+
//determine upper bounded buffer len
_bufflen = (int) Math.min(_rlen*_clen, BUFFER_SIZE);
-
+
}
-
+
// ----------------------------------------------------
// Can extend this by having type hierarchy
public Tuple2<MatrixIndexes, MatrixCell> textToMatrixCell(Text txt) {
FastStringTokenizer st = new FastStringTokenizer(' ');
//get input string (ignore matrix market comments)
String strVal = txt.toString();
- if( strVal.startsWith("%") )
+ if( strVal.startsWith("%") )
return null;
-
+
//parse input ijv triple
st.reset( strVal );
long row = st.nextLong();
@@ -363,19 +346,19 @@ public class RDDConverterUtilsExt
MatrixCell cell = new MatrixCell(val);
return new Tuple2<MatrixIndexes, MatrixCell>(indx, cell);
}
-
+
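
For illustration, a minimal sketch of the ijv parse performed above (the input
line "1 1 7.0" is hypothetical; a nextDouble accessor is assumed analogous to
the nextLong shown in the code):

    FastStringTokenizer st = new FastStringTokenizer(' ');
    st.reset("1 1 7.0");          //1-based row, column, value
    long row = st.nextLong();
    long col = st.nextLong();
    double val = st.nextDouble();
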
public Tuple2<MatrixIndexes, MatrixCell> matrixEntryToMatrixCell(MatrixEntry entry) {
MatrixIndexes indx = new MatrixIndexes(entry.i(), entry.j());
MatrixCell cell = new MatrixCell(entry.value());
return new Tuple2<MatrixIndexes, MatrixCell>(indx, cell);
}
-
+
// ----------------------------------------------------
-
+
Iterable<Tuple2<MatrixIndexes, MatrixBlock>> convertToBinaryBlock(Object arg0, RDDConverterTypes converter) throws Exception {
ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret = new ArrayList<Tuple2<MatrixIndexes,MatrixBlock>>();
ReblockBuffer rbuff = new ReblockBuffer(_bufflen, _rlen, _clen, _brlen, _bclen);
-
+
Iterator<?> iter = (Iterator<?>) arg0;
while( iter.hasNext() ) {
Tuple2<MatrixIndexes, MatrixCell> cell = null;
@@ -383,38 +366,38 @@ public class RDDConverterUtilsExt
case MATRIXENTRY_TO_MATRIXCELL:
cell = matrixEntryToMatrixCell((MatrixEntry) iter.next());
break;
-
+
case TEXT_TO_MATRIX_CELL:
cell = textToMatrixCell((Text) iter.next());
break;
-
+
default:
throw new Exception("Invalid converter for IJV data:" + converter.toString());
}
-
+
if(cell == null) {
continue;
}
-
+
//flush buffer if necessary
if( rbuff.getSize() >= rbuff.getCapacity() )
flushBufferToList(rbuff, ret);
-
+
//add value to reblock buffer
rbuff.appendCell(cell._1.getRowIndex(), cell._1.getColumnIndex(), cell._2.getValue());
}
-
+
//final flush buffer
flushBufferToList(rbuff, ret);
-
+
return ret;
}
- private void flushBufferToList( ReblockBuffer rbuff, ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret )
+ private void flushBufferToList( ReblockBuffer rbuff, ArrayList<Tuple2<MatrixIndexes,MatrixBlock>> ret )
throws IOException, DMLRuntimeException
{
//temporary list of indexed matrix values to prevent library dependencies
- ArrayList<IndexedMatrixValue> rettmp = new ArrayList<IndexedMatrixValue>();
+ ArrayList<IndexedMatrixValue> rettmp = new ArrayList<IndexedMatrixValue>();
rbuff.flushBufferToBinaryBlocks(rettmp);
ret.addAll(SparkUtils.fromIndexedMatrixBlock(rettmp));
}
@@ -423,50 +406,17 @@ public class RDDConverterUtilsExt
/**
* Convert a dataframe of comma-separated string rows to a dataframe of
* ml.linalg.Vector rows.
- *
- * <p>
- * Example input rows:<br>
- *
- * <code>
- * ((1.2, 4.3, 3.4))<br>
- * (1.2, 3.4, 2.2)<br>
- * [[1.2, 34.3, 1.2, 1.25]]<br>
- * [1.2, 3.4]<br>
- * </code>
- *
- * @param sqlContext
- * Spark SQL Context
- * @param inputDF
- * dataframe of comma-separated row strings to convert to
- * dataframe of ml.linalg.Vector rows
- * @return dataframe of ml.linalg.Vector rows
- * @throws DMLRuntimeException
- * if DMLRuntimeException occurs
- *
- * @deprecated This will be removed in SystemML 1.0. Please migrate to {@code
- * RDDConverterUtilsExt.stringDataFrameToVectorDataFrame(SparkSession, Dataset<Row>) }
- */
- @Deprecated
- public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF)
- throws DMLRuntimeException {
- SparkSession sparkSession = sqlContext.sparkSession();
- return stringDataFrameToVectorDataFrame(sparkSession, inputDF);
- }
-
- /**
- * Convert a dataframe of comma-separated string rows to a dataframe of
- * ml.linalg.Vector rows.
- *
+ *
* <p>
* Example input rows:<br>
- *
+ *
* <code>
* ((1.2, 4.3, 3.4))<br>
* (1.2, 3.4, 2.2)<br>
* [[1.2, 34.3, 1.2, 1.25]]<br>
* [1.2, 3.4]<br>
* </code>
- *
+ *
* @param sparkSession
* Spark Session
* @param inputDF
[4/4] incubator-systemml git commit: [SYSTEMML-1303] Remove
deprecated old MLContext API
Posted by de...@apache.org.
[SYSTEMML-1303] Remove deprecated old MLContext API
Remove deprecated old MLContext API, scheduled to be removed in version 1.0.0.
Closes #511.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/7ba17c7f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/7ba17c7f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/7ba17c7f
Branch: refs/heads/master
Commit: 7ba17c7f6604171b6c569f33a3823cf660d536d9
Parents: 0a89676
Author: Deron Eriksson <de...@us.ibm.com>
Authored: Thu May 25 23:45:41 2017 -0700
Committer: Deron Eriksson <de...@us.ibm.com>
Committed: Thu May 25 23:45:41 2017 -0700
----------------------------------------------------------------------
src/main/java/org/apache/sysml/api/MLBlock.java | 280 ---
.../java/org/apache/sysml/api/MLContext.java | 1608 ------------------
.../org/apache/sysml/api/MLContextProxy.java | 50 +-
.../java/org/apache/sysml/api/MLMatrix.java | 428 -----
.../java/org/apache/sysml/api/MLOutput.java | 267 ---
.../org/apache/sysml/api/python/SystemML.py | 232 ---
.../context/SparkExecutionContext.java | 694 ++++----
.../spark/functions/GetMLBlock.java | 43 -
.../spark/utils/RDDConverterUtilsExt.java | 166 +-
.../test/integration/AutomatedTestBase.java | 433 +++--
10 files changed, 622 insertions(+), 3579 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/MLBlock.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLBlock.java b/src/main/java/org/apache/sysml/api/MLBlock.java
deleted file mode 100644
index 69dc5fc..0000000
--- a/src/main/java/org/apache/sysml/api/MLBlock.java
+++ /dev/null
@@ -1,280 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.api;
-
-import java.math.BigDecimal;
-import java.sql.Date;
-import java.sql.Timestamp;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-import scala.collection.JavaConversions;
-import scala.collection.Seq;
-import scala.collection.mutable.Buffer;
-
-/**
- * @deprecated This will be removed in SystemML 1.0. Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}
- */
-@Deprecated
-public class MLBlock implements Row {
-
- private static final long serialVersionUID = -770986277854643424L;
-
- public MatrixIndexes indexes;
- public MatrixBlock block;
-
- public MLBlock(MatrixIndexes indexes, MatrixBlock block) {
- this.indexes = indexes;
- this.block = block;
- }
-
- @Override
- public boolean anyNull() {
- // TODO
- return false;
- }
-
- @Override
- public Object apply(int arg0) {
- if(arg0 == 0) {
- return indexes;
- }
- else if(arg0 == 1) {
- return block;
- }
- // TODO: For now not supporting any operations
- return 0;
- }
-
- @Override
- public Row copy() {
- return new MLBlock(new MatrixIndexes(indexes), new MatrixBlock(block));
- }
-
- @Override
- public Object get(int arg0) {
- if(arg0 == 0) {
- return indexes;
- }
- else if(arg0 == 1) {
- return block;
- }
- // TODO: For now not supporting any operations
- return 0;
- }
-
- @Override
- public <T> T getAs(int arg0) {
- // TODO
- return null;
- }
-
- @Override
- public <T> T getAs(String arg0) {
- // TODO
- return null;
- }
-
- @Override
- public boolean getBoolean(int arg0) {
- // TODO
- return false;
- }
-
- @Override
- public byte getByte(int arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public Date getDate(int arg0) {
- // TODO
- return null;
- }
-
- @Override
- public BigDecimal getDecimal(int arg0) {
- // TODO
- return null;
- }
-
- @Override
- public double getDouble(int arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public float getFloat(int arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public int getInt(int arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public <K, V> Map<K, V> getJavaMap(int arg0) {
- return null;
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public <T> List<T> getList(int arg0) {
- ArrayList<Object> retVal = new ArrayList<Object>();
- retVal.add(indexes);
- retVal.add(block);
- //retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(indexes, block));
- return (List<T>) scala.collection.JavaConversions.asScalaBuffer(retVal).toList();
- }
-
- @Override
- public long getLong(int arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public int fieldIndex(String arg0) {
- // TODO
- return 0;
- }
-
- @Override
- public <K, V> scala.collection.Map<K, V> getMap(int arg0) {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public <T> scala.collection.immutable.Map<String, T> getValuesMap(Seq<String> arg0) {
- // TODO Auto-generated method stub
- return null;
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public <T> Seq<T> getSeq(int arg0) {
- ArrayList<Object> retVal = new ArrayList<Object>();
- retVal.add(indexes);
- retVal.add(block);
- // retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(indexes, block));
- @SuppressWarnings("rawtypes")
- Buffer scBuf = JavaConversions.asScalaBuffer(retVal);
- return scBuf.toSeq();
- }
-
- @Override
- public short getShort(int arg0) {
- // TODO Auto-generated method stub
- return 0;
- }
-
- @Override
- public String getString(int arg0) {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public Row getStruct(int arg0) {
- return this;
- }
-
- @Override
- public boolean isNullAt(int arg0) {
- // TODO Auto-generated method stub
- return false;
- }
-
- @Override
- public int length() {
- return 2;
- }
-
- @Override
- public String mkString() {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public String mkString(String arg0) {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public String mkString(String arg0, String arg1, String arg2) {
- // TODO Auto-generated method stub
- return null;
- }
-
- @Override
- public StructType schema() {
- return getDefaultSchemaForBinaryBlock();
- }
-
-
- @Override
- public int size() {
- return 2;
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public Seq<Object> toSeq() {
- ArrayList<Object> retVal = new ArrayList<Object>();
- retVal.add(indexes);
- retVal.add(block);
- // retVal.add(new Tuple2<MatrixIndexes, MatrixBlock>(indexes, block));
- @SuppressWarnings("rawtypes")
- Buffer scBuf = JavaConversions.asScalaBuffer(retVal);
- return scBuf.toSeq();
- }
-
- public static StructType getDefaultSchemaForBinaryBlock() {
- // TODO:
- StructField[] fields = new StructField[2];
- fields[0] = new StructField("IgnoreSchema", DataType.fromJson("DoubleType"), true, null);
- fields[1] = new StructField("IgnoreSchema1", DataType.fromJson("DoubleType"), true, null);
- return new StructType(fields);
- }
-
- // required for Spark 1.6+
- public Timestamp getTimestamp(int position) {
- // position 0 = MatrixIndexes and position 1 = MatrixBlock,
- // so return null since neither is of date type
- return null;
- }
-
-
-}
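As an aside, the placeholder schema above constructs StructFields through DataType.fromJson with null metadata; a minimal sketch of the conventional Spark construction for the same two-field schema (not part of the removed code) would be:

    scala> import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
    scala> val schema = StructType(Seq(
         |   StructField("IgnoreSchema", DoubleType, nullable = true),
         |   StructField("IgnoreSchema1", DoubleType, nullable = true)))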
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
deleted file mode 100644
index b3102e9..0000000
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ /dev/null
@@ -1,1608 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.api;
-
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Map.Entry;
-import java.util.Scanner;
-
-import org.apache.hadoop.io.LongWritable;
-import org.apache.hadoop.io.Text;
-import org.apache.spark.SparkContext;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
-import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
-import org.apache.sysml.api.jmlc.JMLCUtils;
-import org.apache.sysml.api.mlcontext.ScriptType;
-import org.apache.sysml.conf.CompilerConfig;
-import org.apache.sysml.conf.CompilerConfig.ConfigType;
-import org.apache.sysml.conf.ConfigurationManager;
-import org.apache.sysml.conf.DMLConfig;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.hops.OptimizerUtils.OptimizationLevel;
-import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper;
-import org.apache.sysml.hops.rewrite.ProgramRewriter;
-import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite;
-import org.apache.sysml.parser.DMLProgram;
-import org.apache.sysml.parser.DMLTranslator;
-import org.apache.sysml.parser.DataExpression;
-import org.apache.sysml.parser.Expression;
-import org.apache.sysml.parser.Expression.ValueType;
-import org.apache.sysml.parser.IntIdentifier;
-import org.apache.sysml.parser.LanguageException;
-import org.apache.sysml.parser.ParseException;
-import org.apache.sysml.parser.ParserFactory;
-import org.apache.sysml.parser.ParserWrapper;
-import org.apache.sysml.parser.StringIdentifier;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.LocalVariableMap;
-import org.apache.sysml.runtime.controlprogram.Program;
-import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
-import org.apache.sysml.runtime.controlprogram.caching.FrameObject;
-import org.apache.sysml.runtime.controlprogram.caching.MatrixObject;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
-import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.Instruction;
-import org.apache.sysml.runtime.instructions.cp.Data;
-import org.apache.sysml.runtime.instructions.spark.data.RDDObject;
-import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLongTextPair;
-import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction;
-import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
-import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.runtime.instructions.spark.utils.SparkUtils;
-import org.apache.sysml.runtime.io.IOUtilFunctions;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.MatrixFormatMetaData;
-import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
-import org.apache.sysml.runtime.matrix.data.FileFormatProperties;
-import org.apache.sysml.runtime.matrix.data.FrameBlock;
-import org.apache.sysml.runtime.matrix.data.InputInfo;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-import org.apache.sysml.runtime.matrix.data.OutputInfo;
-import org.apache.sysml.utils.Explain;
-import org.apache.sysml.utils.Explain.ExplainCounts;
-import org.apache.sysml.utils.Statistics;
-
-/**
- * MLContext is useful for passing RDDs as input/output to SystemML. This API avoids the need to read/write
- * from HDFS (which is another way to pass inputs to SystemML).
- * <p>
- * Typical usage for MLContext is as follows:
- * <pre><code>
- * scala> import org.apache.sysml.api.MLContext
- * </code></pre>
- * <p>
- * Create input DataFrame from CSV file and potentially perform some feature transformation
- * <pre><code>
- * scala> val W = sparkSession.read.format("com.databricks.spark.csv").options(Map("path" -> "W.csv", "header" -> "false")).load()
- * scala> val H = sparkSession.read.format("com.databricks.spark.csv").options(Map("path" -> "H.csv", "header" -> "false")).load()
- * scala> val V = sparkSession.read.format("com.databricks.spark.csv").options(Map("path" -> "V.csv", "header" -> "false")).load()
- * </code></pre>
- * <p>
- * Create MLContext
- * <pre><code>
- * scala> val ml = new MLContext(sc)
- * </code></pre>
- * <p>
- * Register input and output DataFrames/RDDs.
- * Supported formats:
- * <ol>
- * <li> DataFrame
- * <li> CSV/Text (as JavaRDD<String> or JavaPairRDD<LongWritable, Text>)
- * <li> Binary blocked RDD (JavaPairRDD<MatrixIndexes,MatrixBlock>)
- * </ol>
- * Also overloaded to support metadata information such as format, rlen, clen, ...
- * Please note the variable names given below in quotes correspond to the variables in the DML script.
- * These variables need to have corresponding read/write statements associated with them in the DML script.
- * Currently, only matrix variables are supported through the registerInput/registerOutput interface.
- * To pass scalar variables, use named/positional arguments (described later) or wrap them into a matrix variable.
- * <pre><code>
- * scala> ml.registerInput("V", V)
- * scala> ml.registerInput("W", W)
- * scala> ml.registerInput("H", H)
- * scala> ml.registerOutput("H")
- * scala> ml.registerOutput("W")
- * </code></pre>
- * <p>
- * Call script with default arguments:
- * <pre><code>
- * scala> val outputs = ml.execute("GNMF.dml")
- * </code></pre>
- * <p>
- * Also supported: calling script with positional arguments (args) and named arguments (nargs):
- * <pre><code>
- * scala> val args = Array("V.mtx", "W.mtx", "H.mtx", "2000", "1500", "50", "1", "WOut.mtx", "HOut.mtx")
- * scala> val nargs = Map("maxIter"->"1", "V" -> "")
- * scala> val outputs = ml.execute("GNMF.dml", args) // or ml.execute("GNMF_namedArgs.dml", nargs)
- * </code></pre>
- * <p>
- * To run the script again using different (or even the same) arguments, but the same registered inputs/outputs:
- * <pre><code>
- * scala> val new_outputs = ml.execute("GNMF.dml", new_args)
- * </code></pre>
- * <p>
- * However, to register new inputs/outputs, you first need to reset the MLContext
- * <pre><code>
- * scala> ml.reset()
- * scala> ml.registerInput("V", newV)
- * </code></pre>
- * <p>
- * Experimental API:
- * To monitor performance (only supported for Spark 1.4.0 or higher),
- * <pre><code>
- * scala> val ml = new MLContext(sc, true)
- * </code></pre>
- * <p>
- * If performance monitoring is enabled,
- * <pre><code>
- * scala> print(ml.getMonitoringUtil().getExplainOutput())
- * scala> ml.getMonitoringUtil().getRuntimeInfoInHTML("runtime.html")
- * </code></pre>
- * <p>
- * Note: The execute(...) methods do not support parallel calls from the same or different MLContexts.
- * This is because the current SystemML engine does not allow multiple invocations in the same JVM.
- * So, if you plan to create a system which potentially creates multiple MLContexts,
- * it is recommended to guard the execute(...) call using
- * <pre><code>
- * synchronized(MLContext.class) { ml.execute(...); }
- * </code></pre>
- *
- * @deprecated This will be removed in SystemML 1.0. Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}
- */
-@Deprecated
-public class MLContext {
-
- // ----------------------------------------------------
- // TODO: To make MLContext multi-threaded, track getCurrentMLContext and also all singletons and
- // static variables in SystemML codebase.
- private static MLContext _activeMLContext = null;
-
- // Package protected so as to maintain a clean public API for MLContext.
- // Use MLContextProxy.getActiveMLContext() if necessary
- static MLContext getActiveMLContext() {
- return _activeMLContext;
- }
- // ----------------------------------------------------
-
- private SparkContext _sc = null; // Read while creating SystemML's spark context
- public SparkContext getSparkContext() {
- if(_sc == null) {
- throw new RuntimeException("No spark context set in MLContext");
- }
- return _sc;
- }
- private ArrayList<String> _inVarnames = null;
- private ArrayList<String> _outVarnames = null;
- private LocalVariableMap _variables = null; // temporary symbol table
- private Program _rtprog = null;
-
- private Map<String, String> _additionalConfigs = new HashMap<String, String>();
-
- /**
- * Create an associated MLContext for the given SparkContext.
- * @param sc SparkContext
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public MLContext(SparkContext sc) throws DMLRuntimeException {
- initializeSpark(sc, false, false);
- }
-
- /**
- * Create an associated MLContext for the given JavaSparkContext.
- * @param sc JavaSparkContext
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public MLContext(JavaSparkContext sc) throws DMLRuntimeException {
- initializeSpark(sc.sc(), false, false);
- }
-
- /**
- * Allow users to provide custom named-value configuration.
- * @param paramName parameter name
- * @param paramVal parameter value
- */
- public void setConfig(String paramName, String paramVal) {
- _additionalConfigs.put(paramName, paramVal);
- }
-
- // ====================================================================================
- // Register input APIs
- // 1. DataFrame
-
- /**
- * Register DataFrame as input. DataFrame is assumed to be in row format and each cell can be converted into a double
- * through Double.parseDouble(cell.toString()). This is suitable for passing dense matrices. For sparse matrices,
- * consider passing through text format (using JavaRDD<String>, format="text")
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param df the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, Dataset<Row> df) throws DMLRuntimeException {
- registerInput(varName, df, false);
- }
-
- /**
- * Register DataFrame as input. DataFrame is assumed to be in row format and each row can be converted into a
- * SystemML frame row. Each column can be of type Double, Float, Long, Integer, String or Boolean.
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param df the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerFrameInput(String varName, Dataset<Row> df) throws DMLRuntimeException {
- registerFrameInput(varName, df, false);
- }
-
- /**
- * Register DataFrame as input.
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param df the DataFrame
- * @param containsID true if the DataFrame has a column ID which denotes the row ID.
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, Dataset<Row> df, boolean containsID) throws DMLRuntimeException {
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mcOut = new MatrixCharacteristics(-1, -1, blksz, blksz);
- JavaPairRDD<MatrixIndexes, MatrixBlock> rdd = RDDConverterUtils
- .dataFrameToBinaryBlock(new JavaSparkContext(_sc), df, mcOut, containsID, false);
- registerInput(varName, rdd, mcOut);
- }
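A short usage sketch for the DataFrame variants above, assuming the spark-csv reader from the class javadoc and an already constructed MLContext ml (the file name is illustrative):

    scala> val V = sparkSession.read.format("com.databricks.spark.csv").options(Map("path" -> "V.csv", "header" -> "false")).load()
    scala> ml.registerInput("V", V, false)  // false: this DataFrame carries no row-ID column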
-
- /**
- * Register DataFrame as input. DataFrame is assumed to be in row format and each row can be converted into a
- * SystemML frame row. Each column can be of type Double, Float, Long, Integer, String or Boolean.
- * <p>
- * @param varName variable name
- * @param df the DataFrame
- * @param containsID true if the DataFrame has a column ID which denotes the row ID.
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerFrameInput(String varName, Dataset<Row> df, boolean containsID) throws DMLRuntimeException {
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mcOut = new MatrixCharacteristics(-1, -1, blksz, blksz);
- JavaPairRDD<Long, FrameBlock> rdd = FrameRDDConverterUtils.dataFrameToBinaryBlock(new JavaSparkContext(_sc), df, mcOut, containsID);
- registerInput(varName, rdd, mcOut.getRows(), mcOut.getCols(), null);
- }
-
- /**
- * Experimental API. Not supported in Python MLContext API.
- * @param varName variable name
- * @param df the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, MLMatrix df) throws DMLRuntimeException {
- registerInput(varName, MLMatrix.getRDDLazily(df), df.mc);
- }
-
- // ------------------------------------------------------------------------------------
- // 2. CSV/Text: Usually JavaRDD<String>, but also supports JavaPairRDD<LongWritable, Text>
- /**
- * Register CSV/Text as inputs: Method for supplying csv file format properties, but without dimensions or nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @param hasHeader is there a header
- * @param delim the delimiter
- * @param fill if true, fill, otherwise don't fill
- * @param fillValue the fill value
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rdd, String format, boolean hasHeader,
- String delim, boolean fill, double fillValue) throws DMLRuntimeException {
- registerInput(varName, rdd, format, hasHeader, delim, fill, fillValue, -1, -1, -1);
- }
-
- /**
- * Register CSV/Text as inputs: Method for supplying csv file format properties, but without dimensions or nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @param hasHeader is there a header
- * @param delim the delimiter
- * @param fill if true, fill, otherwise don't fill
- * @param fillValue the fill value
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, RDD<String> rdd, String format, boolean hasHeader,
- String delim, boolean fill, double fillValue) throws DMLRuntimeException {
- registerInput(varName, rdd.toJavaRDD(), format, hasHeader, delim, fill, fillValue, -1, -1, -1);
- }
-
- /**
- * Register CSV/Text as inputs: Method for supplying csv file format properties along with dimensions or nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @param hasHeader is there a header
- * @param delim the delimiter
- * @param fill if true, fill, otherwise don't fill
- * @param fillValue the fill value
- * @param rlen rows
- * @param clen columns
- * @param nnz non-zeros
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, RDD<String> rdd, String format, boolean hasHeader,
- String delim, boolean fill, double fillValue, long rlen, long clen, long nnz) throws DMLRuntimeException {
- registerInput(varName, rdd.toJavaRDD(), format, hasHeader, delim, fill, fillValue, rlen, clen, nnz);
- }
-
- /**
- * Register CSV/Text as inputs: Method for supplying csv file format properties along with dimensions or nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaRDD
- * @param format the format
- * @param hasHeader is there a header
- * @param delim the delimiter
- * @param fill if true, fill, otherwise don't fill
- * @param fillValue the fill value
- * @param rlen rows
- * @param clen columns
- * @param nnz non-zeros
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rdd, String format, boolean hasHeader,
- String delim, boolean fill, double fillValue, long rlen, long clen, long nnz) throws DMLRuntimeException {
- CSVFileFormatProperties props = new CSVFileFormatProperties(hasHeader, delim, fill, fillValue, "");
- registerInput(varName, rdd.mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, nnz, props);
- }
-
- /**
- * Register CSV/Text as inputs: Convenience method without dimensions and nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, RDD<String> rdd, String format) throws DMLRuntimeException {
- registerInput(varName, rdd.toJavaRDD().mapToPair(new ConvertStringToLongTextPair()), format, -1, -1, -1, null);
- }
-
- /**
- * Register CSV/Text as inputs: Convenience method without dimensions and nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaRDD
- * @param format the format
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rdd, String format) throws DMLRuntimeException {
- registerInput(varName, rdd.mapToPair(new ConvertStringToLongTextPair()), format, -1, -1, -1, null);
- }
-
- /**
- * Register CSV/Text as inputs: Convenience method with dimensions but no nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaRDD
- * @param format the format
- * @param rlen rows
- * @param clen columns
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rdd, String format, long rlen, long clen) throws DMLRuntimeException {
- registerInput(varName, rdd.mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, -1, null);
- }
-
- /**
- * Register CSV/Text as inputs: Convenience method with dimensions but no nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @param rlen rows
- * @param clen columns
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, RDD<String> rdd, String format, long rlen, long clen) throws DMLRuntimeException {
- registerInput(varName, rdd.toJavaRDD().mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, -1, null);
- }
-
- /**
- * Register CSV/Text as inputs: with dimensions and nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaRDD
- * @param format the format
- * @param rlen rows
- * @param clen columns
- * @param nnz non-zeros
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rdd, String format, long rlen, long clen, long nnz) throws DMLRuntimeException {
- registerInput(varName, rdd.mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, nnz, null);
- }
-
- /**
- * Register CSV/Text as inputs: with dimensions and nnz. It uses default file properties (example: delim, fill, ..)
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the RDD
- * @param format the format
- * @param rlen rows
- * @param clen columns
- * @param nnz non-zeros
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, RDD<String> rdd, String format, long rlen, long clen, long nnz) throws DMLRuntimeException {
- registerInput(varName, rdd.toJavaRDD().mapToPair(new ConvertStringToLongTextPair()), format, rlen, clen, nnz, null);
- }
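A usage sketch for the CSV/Text overloads above (paths, delimiters and dimensions are illustrative; sc is the SparkContext backing ml):

    scala> ml.registerInput("V", sc.textFile("V.csv"), "csv", false, ",", false, 0.0)
    scala> ml.registerInput("W", sc.textFile("W.mtx"), "text", 2000L, 50L)  // "text" requires dimensions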
-
- // All CSV related methods call this one; it provides access to dimensions, nnz, and file properties.
- private void registerInput(String varName, JavaPairRDD<LongWritable, Text> textOrCsv_rdd, String format, long rlen, long clen, long nnz, FileFormatProperties props) throws DMLRuntimeException {
- if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
- throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
- }
-
- if(_variables == null)
- _variables = new LocalVariableMap();
- if(_inVarnames == null)
- _inVarnames = new ArrayList<String>();
-
- MatrixObject mo;
- if( format.equals("csv") ) {
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, blksz, blksz, nnz);
- mo = new MatrixObject(ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo));
- }
- else if( format.equals("text") ) {
- if(rlen == -1 || clen == -1) {
- throw new DMLRuntimeException("The metadata is required in registerInput for format:" + format);
- }
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, blksz, blksz, nnz);
- mo = new MatrixObject(ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo));
- }
- else if( format.equals("mm") ) {
- // TODO: Handle matrix market
- throw new DMLRuntimeException("Matrixmarket format is not yet implemented in registerInput: " + format);
- }
- else {
- throw new DMLRuntimeException("Incorrect format in registerInput: " + format);
- }
-
- JavaPairRDD<LongWritable, Text> rdd = textOrCsv_rdd.mapToPair(new CopyTextInputFunction());
- if(props != null)
- mo.setFileFormatProperties(props);
- mo.setRDDHandle(new RDDObject(rdd, varName));
- _variables.put(varName, mo);
- _inVarnames.add(varName);
- checkIfRegisteringInputAllowed();
- }
-
- /**
- * Register Frame with CSV/Text as inputs: with dimensions.
- * File properties (example: delim, fill, ..) can be specified through props; otherwise defaults will be used.
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rddIn the JavaPairRDD
- * @param format the format
- * @param rlen rows
- * @param clen columns
- * @param props properties
- * @param schema List of column types
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaRDD<String> rddIn, String format, long rlen, long clen, FileFormatProperties props,
- List<ValueType> schema) throws DMLRuntimeException {
- if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
- throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
- }
-
- long nnz = -1;
- if(_variables == null)
- _variables = new LocalVariableMap();
- if(_inVarnames == null)
- _inVarnames = new ArrayList<String>();
-
- JavaPairRDD<LongWritable, Text> rddText = rddIn.mapToPair(new ConvertStringToLongTextPair());
-
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, blksz, blksz, nnz);
- FrameObject fo = null;
- if( format.equals("csv") ) {
- CSVFileFormatProperties csvprops = (props!=null) ? (CSVFileFormatProperties)props: new CSVFileFormatProperties();
- fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo));
- fo.setFileFormatProperties(csvprops);
- }
- else if( format.equals("text") ) {
- if(rlen == -1 || clen == -1) {
- throw new DMLRuntimeException("The metadata is required in registerInput for format:" + format);
- }
- fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo));
- }
- else {
- throw new DMLRuntimeException("Incorrect format in registerInput: " + format);
- }
- if(props != null)
- fo.setFileFormatProperties(props);
-
- fo.setRDDHandle(new RDDObject(rddText, varName));
- fo.setSchema("String"); //TODO fix schema
- _variables.put(varName, fo);
- _inVarnames.add(varName);
- checkIfRegisteringInputAllowed();
- }
-
- private void registerInput(String varName, JavaPairRDD<Long, FrameBlock> rdd, long rlen, long clen, FileFormatProperties props) throws DMLRuntimeException {
- if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
- throw new DMLRuntimeException("The registerInput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
- }
-
- if(_variables == null)
- _variables = new LocalVariableMap();
- if(_inVarnames == null)
- _inVarnames = new ArrayList<String>();
-
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, blksz, blksz, -1);
- FrameObject fo = new FrameObject(OptimizerUtils.getUniqueTempFileName(), new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
-
- if(props != null)
- fo.setFileFormatProperties(props);
-
- fo.setRDDHandle(new RDDObject(rdd, varName));
- _variables.put(varName, fo);
- _inVarnames.add(varName);
- checkIfRegisteringInputAllowed();
- }
-
- // ------------------------------------------------------------------------------------
-
- // 3. Binary blocked RDD: Support JavaPairRDD<MatrixIndexes,MatrixBlock>
-
- /**
- * Register binary blocked RDD with given dimensions, default block sizes and no nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaPairRDD
- * @param rlen rows
- * @param clen columns
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, long rlen, long clen) throws DMLRuntimeException {
- //TODO replace default blocksize
- registerInput(varName, rdd, rlen, clen, OptimizerUtils.DEFAULT_BLOCKSIZE, OptimizerUtils.DEFAULT_BLOCKSIZE);
- }
-
- /**
- * Register binary blocked RDD with given dimensions, given block sizes and no nnz
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaPairRDD
- * @param rlen rows
- * @param clen columns
- * @param brlen block rows
- * @param bclen block columns
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, long rlen, long clen, int brlen, int bclen) throws DMLRuntimeException {
- registerInput(varName, rdd, rlen, clen, brlen, bclen, -1);
- }
-
-
- /**
- * Register binary blocked RDD with given dimensions, given block sizes and given nnz (preferred).
- * <p>
- * Marks the variable in the DML script as an input variable.
- * Note that this expects a "varName = read(...)" statement in the DML script which, through non-MLContext invocation,
- * would have been created by reading an HDFS file.
- * @param varName variable name
- * @param rdd the JavaPairRDD
- * @param rlen rows
- * @param clen columns
- * @param brlen block rows
- * @param bclen block columns
- * @param nnz non-zeros
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, long rlen, long clen, int brlen, int bclen, long nnz) throws DMLRuntimeException {
- if(rlen == -1 || clen == -1) {
- throw new DMLRuntimeException("The metadata is required in registerInput for binary format");
- }
-
- MatrixCharacteristics mc = new MatrixCharacteristics(rlen, clen, brlen, bclen, nnz);
- registerInput(varName, rdd, mc);
- }
-
- // All binary blocked methods call this.
- public void registerInput(String varName, JavaPairRDD<MatrixIndexes,MatrixBlock> rdd, MatrixCharacteristics mc) throws DMLRuntimeException {
- if(_variables == null)
- _variables = new LocalVariableMap();
- if(_inVarnames == null)
- _inVarnames = new ArrayList<String>();
- // A bug in Spark messes up blocks and indexes due to overly eager reuse of data structures, hence the defensive copy
- JavaPairRDD<MatrixIndexes, MatrixBlock> copyRDD = SparkUtils.copyBinaryBlockMatrix(rdd);
-
- MatrixObject mo = new MatrixObject(ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(),
- new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
- mo.setRDDHandle(new RDDObject(copyRDD, varName));
- _variables.put(varName, mo);
- _inVarnames.add(varName);
- checkIfRegisteringInputAllowed();
- }
-
- public void registerInput(String varName, MatrixBlock mb) throws DMLRuntimeException {
- int blksz = ConfigurationManager.getBlocksize();
- MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), blksz, blksz, mb.getNonZeros());
- registerInput(varName, mb, mc);
- }
-
- public void registerInput(String varName, MatrixBlock mb, MatrixCharacteristics mc) throws DMLRuntimeException {
- if(_variables == null)
- _variables = new LocalVariableMap();
- if(_inVarnames == null)
- _inVarnames = new ArrayList<String>();
- MatrixObject mo = new MatrixObject(ValueType.DOUBLE, OptimizerUtils.getUniqueTempFileName(),
- new MatrixFormatMetaData(mc, OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
- mo.acquireModify(mb);
- mo.release();
- _variables.put(varName, mo);
- _inVarnames.add(varName);
- checkIfRegisteringInputAllowed();
- }
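A sketch of feeding binary blocks from a previous execution back in as input; getBinaryBlockedRDD is assumed to be the accessor on the old MLOutput, and the dimensions/block sizes are illustrative:

    scala> val outputs = ml.execute("GNMF.dml")
    scala> val W = outputs.getBinaryBlockedRDD("W")
    scala> ml.reset()  // required before registering new inputs
    scala> ml.registerInput("W", W, 2000L, 50L, 1000, 1000)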
-
- // =============================================================================================
-
- /**
- * Marks the variable in the DML script as an output variable.
- * Note that this expects a "write(varName, ...)" statement in the DML script which, through non-MLContext invocation,
- * would have written the matrix to HDFS.
- * @param varName variable name
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public void registerOutput(String varName) throws DMLRuntimeException {
- if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
- throw new DMLRuntimeException("The registerOutput functionality only supported for spark runtime. Please use MLContext(sc) instead of default constructor.");
- }
- if(_outVarnames == null)
- _outVarnames = new ArrayList<String>();
- _outVarnames.add(varName);
- if(_variables == null)
- _variables = new LocalVariableMap();
- }
-
- // =============================================================================================
-
- /**
- * Execute DML script by passing named arguments using specified config file.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @param parsePyDML true if pydml, false otherwise
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, Map<String, String> namedArgs, boolean parsePyDML, String configFilePath) throws IOException, DMLException, ParseException {
- String [] args = new String[namedArgs.size()];
- int i = 0;
- for(Entry<String, String> entry : namedArgs.entrySet()) {
- if(entry.getValue().trim().isEmpty())
- args[i] = entry.getKey() + "=\"" + entry.getValue() + "\"";
- else
- args[i] = entry.getKey() + "=" + entry.getValue();
- i++;
- }
- return compileAndExecuteScript(dmlScriptFilePath, args, true, parsePyDML ? ScriptType.PYDML : ScriptType.DML, configFilePath);
- }
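A usage sketch for the named-argument form above; note that the loop wraps empty values in quotes so they survive argument parsing (script and argument names are illustrative):

    scala> val nargs = Map("maxIter" -> "1", "V" -> "")
    scala> val outputs = ml.execute("GNMF_namedArgs.dml", nargs)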
-
- /**
- * Execute DML script by passing named arguments using specified config file.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, Map<String, String> namedArgs, String configFilePath) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, namedArgs, false, configFilePath);
- }
-
- /**
- * Execute DML script by passing named arguments with default configuration.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, Map<String, String> namedArgs) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, namedArgs, false, null);
- }
-
- /**
- * Execute DML script by passing named arguments.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, scala.collection.immutable.Map<String, String> namedArgs) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)));
- }
-
- /**
- * Experimental: Execute PyDML script by passing named arguments if parsePyDML=true.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @param parsePyDML true if pydml, false otherwise
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, Map<String, String> namedArgs, boolean parsePyDML) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, namedArgs, parsePyDML, null);
- }
-
- /**
- * Experimental: Execute PyDML script by passing named arguments if parsePyDML=true.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param namedArgs named arguments
- * @param parsePyDML true if pydml, false otherwise
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, scala.collection.immutable.Map<String, String> namedArgs, boolean parsePyDML) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), parsePyDML);
- }
-
- /**
- * Execute DML script by passing positional arguments using specified config file
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, String [] args, String configFilePath) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, args, false, configFilePath);
- }
-
- /**
- * Execute DML script by passing positional arguments using specified config file
- * This method is implemented for compatibility with Python MLContext.
- * Java/Scala users should use 'MLOutput execute(String dmlScriptFilePath, String [] args, String configFilePath)' instead, as
- * support for the equivalent Scala collections (Seq/ArrayBuffer) is not implemented.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> args, String configFilePath) throws IOException, DMLException, ParseException {
- String [] argsArr = new String[args.size()];
- argsArr = args.toArray(argsArr);
- return execute(dmlScriptFilePath, argsArr, false, configFilePath);
- }
-
- /**
- * Execute DML script by passing positional arguments using default configuration
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, String [] args) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, args, false, null);
- }
-
- /**
- * Execute DML script by passing positional arguments using default configuration.
- * This method is implemented for compatibility with Python MLContext.
- * Java/Scala users should use 'MLOutput execute(String dmlScriptFilePath, String [] args)' instead, as
- * support for the equivalent Scala collections (Seq/ArrayBuffer) is not implemented.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> args) throws IOException, DMLException, ParseException {
- String [] argsArr = new String[args.size()];
- argsArr = args.toArray(argsArr);
- return execute(dmlScriptFilePath, argsArr, false, null);
- }
-
- /**
- * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using default configuration.
- * This method is implemented for compatibility with Python MLContext.
- * Java/Scala users should use 'MLOutput execute(String dmlScriptFilePath, String [] args, boolean parsePyDML)' instead, as
- * support for the equivalent Scala collections (Seq/ArrayBuffer) is not implemented.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param parsePyDML true if pydml, false otherwise
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> args, boolean parsePyDML) throws IOException, DMLException, ParseException {
- String [] argsArr = new String[args.size()];
- argsArr = args.toArray(argsArr);
- return execute(dmlScriptFilePath, argsArr, parsePyDML, null);
- }
-
- /**
- * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using specified config file.
- * This method is implemented for compatibility with Python MLContext.
- * Java/Scala users should use 'MLOutput execute(String dmlScriptFilePath, String [] args, boolean parsePyDML, String configFilePath)' instead, as
- * support for the equivalent Scala collections (Seq/ArrayBuffer) is not implemented.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param parsePyDML true if pydml, false otherwise
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> args, boolean parsePyDML, String configFilePath) throws IOException, DMLException, ParseException {
- String [] argsArr = new String[args.size()];
- argsArr = args.toArray(argsArr);
- return execute(dmlScriptFilePath, argsArr, parsePyDML, configFilePath);
- }
-
- /*
- @NOTE: When calling from SparkR, Map passing from R to Java somehow
- does not work, hence we pass in two arrays representing the keys and
- the values.
- */
- /**
- * Execute DML script by passing named arguments as parallel name/value lists, using specified config file
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param argsName argument names
- * @param argsValues argument values
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
- ArrayList<String> argsValues, String configFilePath)
- throws IOException, DMLException, ParseException {
- HashMap<String, String> newNamedArgs = new HashMap<String, String>();
- if (argsName.size() != argsValues.size()) {
- throw new DMLException("size of argsName " + argsName.size() +
- " is diff than " + " size of argsValues");
- }
- for (int i = 0; i < argsName.size(); i++) {
- String k = argsName.get(i);
- String v = argsValues.get(i);
- newNamedArgs.put(k, v);
- }
- return execute(dmlScriptFilePath, newNamedArgs, configFilePath);
- }
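Although aimed at SparkR, the parallel-array form can also be exercised from Scala; a sketch with illustrative argument names/values, passing null to use the default configuration:

    scala> val names = new java.util.ArrayList[String](java.util.Arrays.asList("maxIter"))
    scala> val values = new java.util.ArrayList[String](java.util.Arrays.asList("1"))
    scala> val out = ml.execute("GNMF_namedArgs.dml", names, values, null)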
- /**
- * Execute DML script by passing named arguments as parallel name/value lists, using default configuration
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param argsName argument names
- * @param argsValues argument values
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, ArrayList<String> argsName,
- ArrayList<String> argsValues)
- throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, argsName, argsValues, null);
- }
-
- /**
- * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using specified config file.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param parsePyDML true if pydml, false otherwise
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, String [] args, boolean parsePyDML, String configFilePath) throws IOException, DMLException, ParseException {
- return compileAndExecuteScript(dmlScriptFilePath, args, false, parsePyDML ? ScriptType.PYDML : ScriptType.DML, configFilePath);
- }
-
- /**
- * Experimental: Execute DML script by passing positional arguments if parsePyDML=true, using default configuration.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param args arguments
- * @param parsePyDML true if pydml, false otherwise
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, String [] args, boolean parsePyDML) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, args, parsePyDML, null);
- }
-
- /**
- * Execute DML script without any arguments using specified config path
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, String configFilePath) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, false, configFilePath);
- }
-
- /**
- * Execute DML script without any arguments using default configuration.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, false, null);
- }
-
- /**
- * Experimental: Execute DML script without any arguments if parsePyDML=true, using specified config path.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param parsePyDML true if pydml, false otherwise
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, boolean parsePyDML, String configFilePath) throws IOException, DMLException, ParseException {
- return compileAndExecuteScript(dmlScriptFilePath, null, false, parsePyDML ? ScriptType.PYDML : ScriptType.DML, configFilePath);
- }
-
- /**
- * Experimental: Execute DML script without any arguments if parsePyDML=true, using default configuration.
- * @param dmlScriptFilePath the dml script can be in local filesystem or in HDFS
- * @param parsePyDML true if pydml, false otherwise
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput execute(String dmlScriptFilePath, boolean parsePyDML) throws IOException, DMLException, ParseException {
- return execute(dmlScriptFilePath, parsePyDML, null);
- }
-
- // -------------------------------- Utility methods begins ----------------------------------------------------------
-
-
- /**
- * Call this method if you want to clear any RDDs set via registerInput, registerOutput.
- * This is required if ml.execute(..) has been called earlier and you want to call a new DML script.
- * Note: By default this does not clean up the configuration set using the setConfig method.
- * To clean up the configuration along with the registered inputs/outputs, please use reset(true).
- * @throws DMLRuntimeException if DMLException occurs
- */
- public void reset()
- throws DMLRuntimeException
- {
- reset(false);
- }
-
- public void reset(boolean cleanupConfig)
- throws DMLRuntimeException
- {
- //cleanup variables from bufferpool, incl evicted files
- //(otherwise memory leak because bufferpool holds references)
- CacheableData.cleanupCacheDir();
-
- //clear mlcontext state
- _inVarnames = null;
- _outVarnames = null;
- _variables = null;
- if(cleanupConfig)
- _additionalConfigs.clear();
- }
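A usage sketch for the reset variants above:

    scala> ml.reset()      // clears registered inputs/outputs, keeps settings from setConfig
    scala> ml.reset(true)  // additionally clears configuration set via setConfig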
-
- /**
- * Used internally
- * @param source the expression
- * @param target the target
- * @throws LanguageException if LanguageException occurs
- */
- void setAppropriateVarsForRead(Expression source, String target)
- throws LanguageException
- {
- boolean isTargetRegistered = isRegisteredAsInput(target);
- boolean isReadExpression = (source instanceof DataExpression && ((DataExpression) source).isRead());
- if(isTargetRegistered && isReadExpression) {
- // Do not check metadata file for registered reads
- ((DataExpression) source).setCheckMetadata(false);
-
- if (((DataExpression)source).getDataType() == Expression.DataType.MATRIX) {
-
- MatrixObject mo = null;
-
- try {
- mo = getMatrixObject(target);
- int blp = source.getBeginLine(); int bcp = source.getBeginColumn();
- int elp = source.getEndLine(); int ecp = source.getEndColumn();
- ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.READNUMNONZEROPARAM, new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp));
-
- if(mo.getMetaData() instanceof MatrixFormatMetaData) {
- MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
- if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.ROWBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
- }
- else {
- throw new LanguageException("Unsupported format through MLContext");
- }
- }
- } catch (DMLRuntimeException e) {
- throw new LanguageException(e);
- }
- } else if (((DataExpression)source).getDataType() == Expression.DataType.FRAME) {
- FrameObject mo = null;
- try {
- mo = getFrameObject(target);
- int blp = source.getBeginLine(); int bcp = source.getBeginColumn();
- int elp = source.getEndLine(); int ecp = source.getEndColumn();
- ((DataExpression) source).addVarParam(DataExpression.READROWPARAM, new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.READCOLPARAM, new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("frame", source.getFilename(), blp, bcp, elp, ecp));
- ((DataExpression) source).addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), blp, bcp, elp, ecp)); //TODO change to schema
-
- if(mo.getMetaData() instanceof MatrixFormatMetaData) {
- MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData();
- if(metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp));
- }
- else if(metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) {
- ((DataExpression) source).addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier(DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp));
- }
- else {
- throw new LanguageException("Unsupported format through MLContext");
- }
- }
- } catch (DMLRuntimeException e) {
- throw new LanguageException(e);
- }
- }
- }
- }
-
- /**
- * Used internally
- * @param tmp list of instructions
- * @return list of instructions
- */
- ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp) {
- String [] outputs = (_outVarnames != null) ? _outVarnames.toArray(new String[0]) : new String[0];
- return JMLCUtils.cleanupRuntimeInstructions(tmp, outputs);
- }
-
- // -------------------------------- Utility methods ends ----------------------------------------------------------
-
- // -------------------------------- Private methods begins ----------------------------------------------------------
- private boolean isRegisteredAsInput(String varName) {
- if(_inVarnames != null) {
- for(String v : _inVarnames) {
- if(v.equals(varName)) {
- return true;
- }
- }
- }
- return false;
- }
-
- private MatrixObject getMatrixObject(String varName) throws DMLRuntimeException {
- if(_variables != null) {
- Data mo = _variables.get(varName);
- if(mo instanceof MatrixObject) {
- return (MatrixObject) mo;
- }
- else {
- throw new DMLRuntimeException("ERROR: Incorrect type");
- }
- }
- throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName);
- }
-
- private FrameObject getFrameObject(String varName) throws DMLRuntimeException {
- if(_variables != null) {
- Data mo = _variables.get(varName);
- if(mo instanceof FrameObject) {
- return (FrameObject) mo;
- }
- else {
- throw new DMLRuntimeException("ERROR: Incorrect type");
- }
- }
- throw new DMLRuntimeException("ERROR: getMatrixObject not set for variable:" + varName);
- }
-
- private int compareVersion(String versionStr1, String versionStr2) {
- Scanner s1 = null;
- Scanner s2 = null;
- try {
- s1 = new Scanner(versionStr1); s1.useDelimiter("\\.");
- s2 = new Scanner(versionStr2); s2.useDelimiter("\\.");
- while(s1.hasNextInt() && s2.hasNextInt()) {
- int version1 = s1.nextInt();
- int version2 = s2.nextInt();
- if(version1 < version2) {
- return -1;
- } else if(version1 > version2) {
- return 1;
- }
- }
-
- if(s1.hasNextInt()) return 1;
- if(s2.hasNextInt()) return -1;
- }
- finally {
- IOUtilFunctions.closeSilently(s1);
- IOUtilFunctions.closeSilently(s2);
- }
-
- return 0;
- }
-
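
For reference, the intended semantics of this helper on dot-separated numeric version strings:

    compareVersion("1.3.0", "1.6.0") // -1 (first is older)
    compareVersion("2.0.1", "2.0.1") //  0
    compareVersion("1.10",  "1.9")   //  1 (fields compare numerically, not lexicographically)
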
- private void initializeSpark(SparkContext sc, boolean monitorPerformance, boolean setForcedSparkExecType) throws DMLRuntimeException {
- MLContextProxy.setActive(true);
-
- this._sc = sc;
-
- if(compareVersion(sc.version(), "1.3.0") < 0 ) {
- throw new DMLRuntimeException("Expected spark version >= 1.3.0 for running SystemML");
- }
-
- if(setForcedSparkExecType)
- DMLScript.rtplatform = RUNTIME_PLATFORM.SPARK;
- else
- DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK;
- }
-
-
- /**
- * Execute a script stored in a string.
- *
- * @param dmlScript the script
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLOutput executeScript(String dmlScript)
- throws IOException, DMLException {
- return executeScript(dmlScript, false);
- }
-
- public MLOutput executeScript(String dmlScript, boolean isPyDML)
- throws IOException, DMLException {
- return executeScript(dmlScript, isPyDML, null);
- }
-
- public MLOutput executeScript(String dmlScript, String configFilePath)
- throws IOException, DMLException {
- return executeScript(dmlScript, false, configFilePath);
- }
-
- public MLOutput executeScript(String dmlScript, boolean isPyDML, String configFilePath)
- throws IOException, DMLException {
- return compileAndExecuteScript(dmlScript, null, false, false, isPyDML ? ScriptType.PYDML : ScriptType.DML, configFilePath);
- }
-
- /*
- @NOTE: when calling from SparkR, passing a HashMap from R to Java somehow
- does not work, and hence we pass in two arrays representing the keys
- and the values
- */
- public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
- ArrayList<String> argsValues, String configFilePath)
- throws IOException, DMLException, ParseException {
- HashMap<String, String> newNamedArgs = new HashMap<String, String>();
- if (argsName.size() != argsValues.size()) {
- throw new DMLException("size of argsName " + argsName.size() +
- " is diff than " + " size of argsValues");
- }
- for (int i = 0; i < argsName.size(); i++) {
- String k = argsName.get(i);
- String v = argsValues.get(i);
- newNamedArgs.put(k, v);
- }
- return executeScript(dmlScript, newNamedArgs, configFilePath);
- }
-
- public MLOutput executeScript(String dmlScript, ArrayList<String> argsName,
- ArrayList<String> argsValues)
- throws IOException, DMLException, ParseException {
- return executeScript(dmlScript, argsName, argsValues, null);
- }
-
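
A minimal Scala sketch of the two-array calling convention above (the script string and argument names are assumptions):

    val names  = new java.util.ArrayList[String](java.util.Arrays.asList("X", "eps"))
    val values = new java.util.ArrayList[String](java.util.Arrays.asList("in/X.mtx", "0.001"))
    val out = ml.executeScript(scriptString, names, values) // keys zipped with values into named args
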
-
- public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs)
- throws IOException, DMLException {
- return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), null);
- }
-
- public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, boolean isPyDML)
- throws IOException, DMLException {
- return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), isPyDML, null);
- }
-
- public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, String configFilePath)
- throws IOException, DMLException {
- return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), configFilePath);
- }
-
- public MLOutput executeScript(String dmlScript, scala.collection.immutable.Map<String, String> namedArgs, boolean isPyDML, String configFilePath)
- throws IOException, DMLException {
- return executeScript(dmlScript, new HashMap<String, String>(scala.collection.JavaConversions.mapAsJavaMap(namedArgs)), isPyDML, configFilePath);
- }
-
- public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs)
- throws IOException, DMLException {
- return executeScript(dmlScript, namedArgs, null);
- }
-
- public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, boolean isPyDML)
- throws IOException, DMLException {
- return executeScript(dmlScript, namedArgs, isPyDML, null);
- }
-
- public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, String configFilePath)
- throws IOException, DMLException {
- return executeScript(dmlScript, namedArgs, false, configFilePath);
- }
-
- public MLOutput executeScript(String dmlScript, Map<String, String> namedArgs, boolean isPyDML, String configFilePath)
- throws IOException, DMLException {
- String [] args = new String[namedArgs.size()];
- int i = 0;
- for(Entry<String, String> entry : namedArgs.entrySet()) {
- if(entry.getValue().trim().isEmpty())
- args[i] = entry.getKey() + "=\"" + entry.getValue() + "\"";
- else
- args[i] = entry.getKey() + "=" + entry.getValue();
- i++;
- }
- return compileAndExecuteScript(dmlScript, args, false, true, isPyDML ? ScriptType.PYDML : ScriptType.DML, configFilePath);
- }
-
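
A minimal Scala sketch of the named-argument overloads above; note that an empty value is passed through quoted (key="") so the argument still reaches the script:

    val out = ml.executeScript(
      """print("hello " + $msg);""",
      Map("msg" -> "world", "opt" -> ""))
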
- private void checkIfRegisteringInputAllowed() throws DMLRuntimeException {
- if(!(DMLScript.rtplatform == RUNTIME_PLATFORM.SPARK || DMLScript.rtplatform == RUNTIME_PLATFORM.HYBRID_SPARK)) {
- throw new DMLRuntimeException("ERROR: registerInput is only allowed for spark execution mode");
- }
- }
-
- private MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] args, boolean isNamedArgument, ScriptType scriptType, String configFilePath) throws IOException, DMLException {
- return compileAndExecuteScript(dmlScriptFilePath, args, true, isNamedArgument, scriptType, configFilePath);
- }
-
- /**
- * All the execute() methods call this method, which, after setting the appropriate input/output variables,
- * calls _compileAndExecuteScript.
- * We have explicitly synchronized this function because MLContext/SystemML does not yet support multi-threading.
- * @throws ParseException if ParseException occurs
- * @param dmlScriptFilePath script file path
- * @param args arguments
- * @param isFile whether the string is a path
- * @param isNamedArgument is named argument
- * @param scriptType type of script (DML or PyDML)
- * @param configFilePath path to config file
- * @return output as MLOutput
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- */
- private synchronized MLOutput compileAndExecuteScript(String dmlScriptFilePath, String [] args, boolean isFile, boolean isNamedArgument, ScriptType scriptType, String configFilePath) throws IOException, DMLException {
- try {
-
- DMLScript.SCRIPT_TYPE = scriptType;
-
- if(getActiveMLContext() != null) {
- throw new DMLRuntimeException("SystemML (and hence by definition MLContext) doesnot support parallel execute() calls from same or different MLContexts. "
- + "As a temporary fix, please do explicit synchronization, i.e. synchronized(MLContext.class) { ml.execute(...) } ");
- }
-
- // Set active MLContext.
- _activeMLContext = this;
-
- if( OptimizerUtils.isSparkExecutionMode() ) {
- // Depending on whether registerInput/registerOutput was called initialize the variables
- String[] inputs = (_inVarnames != null) ? _inVarnames.toArray(new String[0]) : new String[0];
- String[] outputs = (_outVarnames != null) ? _outVarnames.toArray(new String[0]) : new String[0];
- Map<String, JavaPairRDD<?,?>> retVal = (_outVarnames != null && !_outVarnames.isEmpty()) ?
- new HashMap<String, JavaPairRDD<?,?>>() : null;
- Map<String, MatrixCharacteristics> outMetadata = new HashMap<String, MatrixCharacteristics>();
- Map<String, String> argVals = DMLScript.createArgumentsMap(isNamedArgument, args);
-
- // Run the DML script
- ExecutionContext ec = executeUsingSimplifiedCompilationChain(dmlScriptFilePath, isFile, argVals, scriptType, inputs, outputs, _variables, configFilePath);
- SparkExecutionContext sec = (SparkExecutionContext) ec;
-
- // Now collect the output
- if(_outVarnames != null) {
- if(_variables == null)
- throw new DMLRuntimeException("The symbol table returned after executing the script is empty");
-
- for( String ovar : _outVarnames ) {
- if( !_variables.keySet().contains(ovar) )
- throw new DMLException("Error: The variable " + ovar + " is not available as output after the execution of the DMLScript.");
-
- retVal.put(ovar, sec.getRDDHandleForVariable(ovar, InputInfo.BinaryBlockInputInfo));
- outMetadata.put(ovar, ec.getMatrixCharacteristics(ovar)); // For converting output to dataframe
- }
- }
-
- return new MLOutput(retVal, outMetadata);
- }
- else {
- throw new DMLRuntimeException("Unsupported runtime:" + DMLScript.rtplatform.name());
- }
- }
- finally {
- // Remove global dml config and all thread-local configs
- // TODO enable cleanup whenever invalid GNMF MLcontext is fixed
- // (the test is invalid because it assumes that status of previous execute is kept)
- //ConfigurationManager.setGlobalConfig(new DMLConfig());
- //ConfigurationManager.clearLocalConfigs();
-
- // Reset active MLContext.
- _activeMLContext = null;
- }
- }
-
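
The temporary workaround suggested by the error message above, expressed in Scala (the script path is an assumption):

    classOf[MLContext].synchronized {
      ml.execute("script.dml")
    }
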
-
- /**
- * This runs the DML script and returns the ExecutionContext for the caller to extract the output variables.
- * The caller (which is compileAndExecuteScript) is expected to set inputSymbolTable with appropriate matrix representation (RDD, MatrixObject).
- *
- * @param dmlScriptFilePath script file path
- * @param isFile true if file, false otherwise
- * @param argVals map of args
- * @param scriptType type of script (DML or PyDML)
- * @param inputs the inputs
- * @param outputs the outputs
- * @param inputSymbolTable the input symbol table
- * @param configFilePath path to config file
- * @return the execution context
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- private ExecutionContext executeUsingSimplifiedCompilationChain(String dmlScriptFilePath, boolean isFile, Map<String, String> argVals, ScriptType scriptType,
- String[] inputs, String[] outputs, LocalVariableMap inputSymbolTable, String configFilePath)
- throws IOException, DMLException
- {
- //construct dml configuration
- DMLConfig config = (configFilePath == null) ? new DMLConfig() : new DMLConfig(configFilePath);
- for(Entry<String, String> param : _additionalConfigs.entrySet()) {
- config.setTextValue(param.getKey(), param.getValue());
- }
-
- //set global dml and specialized compiler configurations
- ConfigurationManager.setGlobalConfig(config);
- CompilerConfig cconf = new CompilerConfig();
- cconf.set(ConfigType.IGNORE_UNSPECIFIED_ARGS, true);
- cconf.set(ConfigType.REJECT_READ_WRITE_UNKNOWNS, false);
- cconf.set(ConfigType.ALLOW_CSE_PERSISTENT_READS, false);
- ConfigurationManager.setGlobalConfig(cconf);
-
- //read dml script string
- String dmlScriptStr = DMLScript.readDMLScript( isFile, dmlScriptFilePath);
-
- //simplified compilation chain
- _rtprog = null;
-
- //parsing
- ParserWrapper parser = ParserFactory.createParser(scriptType);
- DMLProgram prog;
- if (isFile) {
- prog = parser.parse(dmlScriptFilePath, null, argVals);
- } else {
- prog = parser.parse(null, dmlScriptStr, argVals);
- }
-
- //language validate
- DMLTranslator dmlt = new DMLTranslator(prog);
- dmlt.liveVariableAnalysis(prog);
- dmlt.validateParseTree(prog);
-
- //hop construct/rewrite
- dmlt.constructHops(prog);
- dmlt.rewriteHopsDAG(prog);
-
- Explain.explain(prog);
-
- //rewrite persistent reads/writes
- if(inputSymbolTable != null) {
- RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs, inputSymbolTable);
- ProgramRewriter rewriter2 = new ProgramRewriter(rewrite);
- rewriter2.rewriteProgramHopDAGs(prog);
- }
-
- //lop construct and runtime prog generation
- dmlt.constructLops(prog);
- _rtprog = prog.getRuntimeProgram(config);
-
- //optional global data flow optimization
- if(OptimizerUtils.isOptLevel(OptimizationLevel.O4_GLOBAL_TIME_MEMORY) ) {
- _rtprog = GlobalOptimizerWrapper.optimizeProgram(prog, _rtprog);
- }
-
- // launching the SystemML AppMaster is not required, as it is already launched
-
- //count number compiled MR jobs / SP instructions
- ExplainCounts counts = Explain.countDistributedOperations(_rtprog);
- Statistics.resetNoOfCompiledJobs( counts.numJobs );
-
- // Initialize caching and scratch space
- DMLScript.initHadoopExecution(config);
-
- //final cleanup runtime prog
- JMLCUtils.cleanupRuntimeProgram(_rtprog, outputs);
-
- //create and populate execution context
- ExecutionContext ec = ExecutionContextFactory.createContext(_rtprog);
- if(inputSymbolTable != null) {
- ec.setVariables(inputSymbolTable);
- }
-
- //core execute runtime program
- _rtprog.execute( ec );
-
- return ec;
- }
-
- // -------------------------------- Private methods ends ----------------------------------------------------------
-
- // TODO: Add additional create to provide sep, missing values, etc. for CSV
- /**
- * Experimental API: Might be discontinued in future release
- * @param sparkSession the Spark Session
- * @param filePath the file path
- * @param format the format
- * @return the MLMatrix
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLMatrix read(SparkSession sparkSession, String filePath, String format) throws IOException, DMLException, ParseException {
- this.reset();
- this.registerOutput("output");
- MLOutput out = this.executeScript("output = read(\"" + filePath + "\", format=\"" + format + "\"); " + MLMatrix.writeStmt);
- JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = out.getBinaryBlockedRDD("output");
- MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
- return MLMatrix.createMLMatrix(this, sparkSession, blocks, mcOut);
- }
-
- /**
- * Experimental API: Might be discontinued in future release
- * @param sqlContext the SQL Context
- * @param filePath the file path
- * @param format the format
- * @return the MLMatrix
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- * @throws ParseException if ParseException occurs
- */
- public MLMatrix read(SQLContext sqlContext, String filePath, String format) throws IOException, DMLException, ParseException {
- SparkSession sparkSession = sqlContext.sparkSession();
- return read(sparkSession, filePath, format);
- }
-}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/MLContextProxy.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContextProxy.java b/src/main/java/org/apache/sysml/api/MLContextProxy.java
index db87230..18b2eaa 100644
--- a/src/main/java/org/apache/sysml/api/MLContextProxy.java
+++ b/src/main/java/org/apache/sysml/api/MLContextProxy.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -21,6 +21,7 @@ package org.apache.sysml.api;
import java.util.ArrayList;
+import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLContextException;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.LanguageException;
@@ -31,59 +32,42 @@ import org.apache.sysml.runtime.instructions.Instruction;
* which would try to load spark libraries and hence fail if these are not available. This
* indirection is much more efficient than catching NoClassDefFoundErrors for every access
* to MLContext (e.g., on each recompile).
- *
+ *
*/
-public class MLContextProxy
+public class MLContextProxy
{
-
+
private static boolean _active = false;
-
+
public static void setActive(boolean flag) {
_active = flag;
}
-
+
public static boolean isActive() {
return _active;
}
- @SuppressWarnings("deprecation")
- public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp)
+ public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp)
{
- if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
- return org.apache.sysml.api.MLContext.getActiveMLContext().performCleanupAfterRecompilation(tmp);
- } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
- return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp);
- }
- return tmp;
+ return MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp);
}
- @SuppressWarnings("deprecation")
- public static void setAppropriateVarsForRead(Expression source, String targetname)
- throws LanguageException
+ public static void setAppropriateVarsForRead(Expression source, String targetname)
+ throws LanguageException
{
- if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
- org.apache.sysml.api.MLContext.getActiveMLContext().setAppropriateVarsForRead(source, targetname);
- } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
- org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname);
- }
+ MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname);
}
- @SuppressWarnings("deprecation")
public static Object getActiveMLContext() {
- if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) {
- return org.apache.sysml.api.MLContext.getActiveMLContext();
- } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
- return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext();
- }
- return null;
+ return MLContext.getActiveMLContext();
}
public static Object getActiveMLContextForAPI() {
- if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) {
- return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext();
+ if (MLContext.getActiveMLContext() != null) {
+ return MLContext.getActiveMLContext();
}
throw new MLContextException("No MLContext object is currently active. Have you created one? "
+ "Hint: in Scala, 'val ml = new MLContext(sc)'", true);
}
-
+
}
[3/4] incubator-systemml git commit: [SYSTEMML-1303] Remove
deprecated old MLContext API
Posted by de...@apache.org.
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/MLMatrix.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLMatrix.java b/src/main/java/org/apache/sysml/api/MLMatrix.java
deleted file mode 100644
index 45f631f..0000000
--- a/src/main/java/org/apache/sysml/api/MLMatrix.java
+++ /dev/null
@@ -1,428 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.sysml.api;
-
-import java.io.IOException;
-import java.util.List;
-
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.rdd.RDD;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.encoders.RowEncoder;
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
-import org.apache.spark.sql.execution.QueryExecution;
-import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.hops.OptimizerUtils;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.instructions.spark.functions.GetMIMBFromRow;
-import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-import scala.Tuple2;
-
-/**
- * Experimental API: Might be discontinued in future release
- *
- * This class serves four purposes:
- * 1. It allows SystemML to fit nicely into MLPipeline by reducing the number of reblocks.
- * 2. It allows users to easily read and write matrices without worrying
- * too much about the format, metadata and type of the underlying RDDs.
- * 3. It provides a mechanism to convert to and from MLlib's BlockMatrix format.
- * 4. It provides an off-the-shelf library for distributed blocked matrices and reduces the learning curve for using SystemML.
- * However, it is important to know that it is easy to abuse this off-the-shelf library and to think of it as a replacement
- * for writing DML, which it is not. It does not provide any optimization between calls. A simple example
- * of an optimization that is conveniently skipped is: (t(m) %*% m).
- * Also, note that this library is not thread-safe. The operator precedence is not exactly the same as in DML (as precedence is
- * enforced by the Scala compiler), so please use appropriate brackets to enforce precedence.
-
- import org.apache.sysml.api.{MLContext, MLMatrix}
- val ml = new MLContext(sc)
- val mat1 = ml.read(sparkSession, "V_small.csv", "csv")
- val mat2 = ml.read(sparkSession, "W_small.mtx", "binary")
- val result = mat1.transpose() %*% mat2
- result.write("Result_small.mtx", "text")
-
- * @deprecated This will be removed in SystemML 1.0. Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}
- */
-@Deprecated
-public class MLMatrix extends Dataset<Row> {
- private static final long serialVersionUID = -7005940673916671165L;
-
- protected MatrixCharacteristics mc = null;
- protected MLContext ml = null;
-
- protected MLMatrix(SparkSession sparkSession, LogicalPlan logicalPlan, MLContext ml) {
- super(sparkSession, logicalPlan, RowEncoder.apply(null));
- this.ml = ml;
- }
-
- protected MLMatrix(SQLContext sqlContext, LogicalPlan logicalPlan, MLContext ml) {
- super(sqlContext, logicalPlan, RowEncoder.apply(null));
- this.ml = ml;
- }
-
- protected MLMatrix(SparkSession sparkSession, QueryExecution queryExecution, MLContext ml) {
- super(sparkSession, queryExecution, RowEncoder.apply(null));
- this.ml = ml;
- }
-
- protected MLMatrix(SQLContext sqlContext, QueryExecution queryExecution, MLContext ml) {
- super(sqlContext.sparkSession(), queryExecution, RowEncoder.apply(null));
- this.ml = ml;
- }
-
- // Only used internally to set a new MLMatrix after one of the matrix operations.
- // Not to be used externally.
- protected MLMatrix(Dataset<Row> df, MatrixCharacteristics mc, MLContext ml) throws DMLRuntimeException {
- super(df.sparkSession(), df.logicalPlan(), RowEncoder.apply(null));
- this.mc = mc;
- this.ml = ml;
- }
-
- //TODO replace default blocksize
- static String writeStmt = "write(output, \"tmp\", format=\"binary\", rows_in_block=" + OptimizerUtils.DEFAULT_BLOCKSIZE + ", cols_in_block=" + OptimizerUtils.DEFAULT_BLOCKSIZE + ");";
-
- // ------------------------------------------------------------------------------------------------
-
-// /**
-// * Experimental unstable API: Converts our blocked matrix format to MLLib's format
-// * @return
-// */
-// public BlockMatrix toBlockedMatrix() {
-// JavaPairRDD<MatrixIndexes, MatrixBlock> blocks = getRDDLazily(this);
-// RDD<Tuple2<Tuple2<Object, Object>, Matrix>> mllibBlocks = blocks.mapToPair(new GetMLLibBlocks(mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock())).rdd();
-// return new BlockMatrix(mllibBlocks, mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getRows(), mc.getCols());
-// }
-
- // ------------------------------------------------------------------------------------------------
- static MLMatrix createMLMatrix(MLContext ml, SparkSession sparkSession, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
- RDD<Row> rows = blocks.map(new GetMLBlock()).rdd();
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- return new MLMatrix(sparkSession.createDataFrame(rows.toJavaRDD(), schema), mc, ml);
- }
-
- static MLMatrix createMLMatrix(MLContext ml, SQLContext sqlContext, JavaPairRDD<MatrixIndexes, MatrixBlock> blocks, MatrixCharacteristics mc) throws DMLRuntimeException {
- SparkSession sparkSession = sqlContext.sparkSession();
- return createMLMatrix(ml, sparkSession, blocks, mc);
- }
-
- /**
- * Convenience method to write an MLMatrix.
- *
- * @param filePath the file path
- * @param format the format
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- */
- public void write(String filePath, String format) throws IOException, DMLException {
- ml.reset();
- ml.registerInput("left", this);
- ml.executeScript("left = read(\"\"); output=left; write(output, \"" + filePath + "\", format=\"" + format + "\");");
- }
-
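
A minimal Scala sketch of write(); 'matA' and 'matB' are assumed to have been obtained via ml.read(...):

    val product = matA %*% matB
    product.write("hdfs:/tmp/product.csv", "csv")
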
- private double getScalarBuiltinFunctionResult(String fn) throws IOException, DMLException {
- if(fn.equals("nrow") || fn.equals("ncol")) {
- ml.reset();
- ml.registerInput("left", getRDDLazily(this), mc.getRows(), mc.getCols(), mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
- ml.registerOutput("output");
- String script = "left = read(\"\");"
- + "val = " + fn + "(left); "
- + "output = matrix(val, rows=1, cols=1); "
- + writeStmt;
- MLOutput out = ml.executeScript(script);
- List<Tuple2<MatrixIndexes, MatrixBlock>> result = out.getBinaryBlockedRDD("output").collect();
- if(result == null || result.size() != 1) {
- throw new DMLRuntimeException("Error while computing the function: " + fn);
- }
- return result.get(0)._2.getValue(0, 0);
- }
- else {
- throw new DMLRuntimeException("The function " + fn + " is not yet supported in MLMatrix");
- }
- }
-
- /**
- * Gets or computes the number of rows.
- * @return the number of rows
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- */
- public long numRows() throws IOException, DMLException {
- if(mc.rowsKnown()) {
- return mc.getRows();
- }
- else {
- return (long) getScalarBuiltinFunctionResult("nrow");
- }
- }
-
- /**
- * Gets or computes the number of columns.
- * @return the number of columns
- * @throws IOException if IOException occurs
- * @throws DMLException if DMLException occurs
- */
- public long numCols() throws IOException, DMLException {
- if(mc.colsKnown()) {
- return mc.getCols();
- }
- else {
- return (long) getScalarBuiltinFunctionResult("ncol");
- }
- }
-
- public int rowsPerBlock() {
- return mc.getRowsPerBlock();
- }
-
- public int colsPerBlock() {
- return mc.getColsPerBlock();
- }
-
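
A short sketch of the metadata accessors above; numRows()/numCols() return cached metadata when known and otherwise trigger a small nrow()/ncol() job ('mat' is an assumed MLMatrix):

    val n = mat.numRows()
    val d = mat.numCols()
    val (brlen, bclen) = (mat.rowsPerBlock(), mat.colsPerBlock())
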
- private String getScript(String binaryOperator) {
- return "left = read(\"\");"
- + "right = read(\"\");"
- + "output = left " + binaryOperator + " right; "
- + writeStmt;
- }
-
- private String getScalarBinaryScript(String binaryOperator, double scalar, boolean isScalarLeft) {
- if(isScalarLeft) {
- return "left = read(\"\");"
- + "output = " + scalar + " " + binaryOperator + " left ;"
- + writeStmt;
- }
- else {
- return "left = read(\"\");"
- + "output = left " + binaryOperator + " " + scalar + ";"
- + writeStmt;
- }
- }
-
- static JavaPairRDD<MatrixIndexes, MatrixBlock> getRDDLazily(MLMatrix mat) {
- return mat.rdd().toJavaRDD().mapToPair(new GetMIMBFromRow());
- }
-
- private MLMatrix matrixBinaryOp(MLMatrix that, String op) throws IOException, DMLException {
-
- if(mc.getRowsPerBlock() != that.mc.getRowsPerBlock() || mc.getColsPerBlock() != that.mc.getColsPerBlock()) {
- throw new DMLRuntimeException("Incompatible block sizes: brlen:" + mc.getRowsPerBlock() + "!=" + that.mc.getRowsPerBlock() + " || bclen:" + mc.getColsPerBlock() + "!=" + that.mc.getColsPerBlock());
- }
-
- if(op.equals("%*%")) {
- if(mc.getCols() != that.mc.getRows()) {
- throw new DMLRuntimeException("Dimensions mismatch:" + mc.getCols() + "!=" + that.mc.getRows());
- }
- }
- else {
- if(mc.getRows() != that.mc.getRows() || mc.getCols() != that.mc.getCols()) {
- throw new DMLRuntimeException("Dimensions mismatch:" + mc.getRows() + "!=" + that.mc.getRows() + " || " + mc.getCols() + "!=" + that.mc.getCols());
- }
- }
-
- ml.reset();
- ml.registerInput("left", this);
- ml.registerInput("right", that);
- ml.registerOutput("output");
- MLOutput out = ml.executeScript(getScript(op));
- RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
- return new MLMatrix(this.sparkSession().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
- }
-
- private MLMatrix scalarBinaryOp(Double scalar, String op, boolean isScalarLeft) throws IOException, DMLException {
- ml.reset();
- ml.registerInput("left", this);
- ml.registerOutput("output");
- MLOutput out = ml.executeScript(getScalarBinaryScript(op, scalar, isScalarLeft));
- RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
- return new MLMatrix(this.sparkSession().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
- }
-
- // ---------------------------------------------------
- // Simple operator overloading, but does not utilize the optimizer
-
- public MLMatrix $greater(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, ">");
- }
-
- public MLMatrix $less(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "<");
- }
-
- public MLMatrix $greater$eq(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, ">=");
- }
-
- public MLMatrix $less$eq(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "<=");
- }
-
- public MLMatrix $eq$eq(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "==");
- }
-
- public MLMatrix $bang$eq(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "!=");
- }
-
- public MLMatrix $up(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "^");
- }
-
- public MLMatrix exp(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "^");
- }
-
- public MLMatrix $plus(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "+");
- }
-
- public MLMatrix add(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "+");
- }
-
- public MLMatrix $minus(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "-");
- }
-
- public MLMatrix minus(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "-");
- }
-
- public MLMatrix $times(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "*");
- }
-
- public MLMatrix elementWiseMultiply(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "*");
- }
-
- public MLMatrix $div(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "/");
- }
-
- public MLMatrix divide(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "/");
- }
-
- public MLMatrix $percent$div$percent(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%/%");
- }
-
- public MLMatrix integerDivision(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%/%");
- }
-
- public MLMatrix $percent$percent(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%%");
- }
-
- public MLMatrix modulus(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%%");
- }
-
- public MLMatrix $percent$times$percent(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%*%");
- }
-
- public MLMatrix multiply(MLMatrix that) throws IOException, DMLException {
- return matrixBinaryOp(that, "%*%");
- }
-
- public MLMatrix transpose() throws IOException, DMLException {
- ml.reset();
- ml.registerInput("left", this);
- ml.registerOutput("output");
- String script = "left = read(\"\");"
- + "output = t(left); "
- + writeStmt;
- MLOutput out = ml.executeScript(script);
- RDD<Row> rows = out.getBinaryBlockedRDD("output").map(new GetMLBlock()).rdd();
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- MatrixCharacteristics mcOut = out.getMatrixCharacteristics("output");
- return new MLMatrix(this.sparkSession().createDataFrame(rows.toJavaRDD(), schema), mcOut, ml);
- }
-
- // TODO: For 'scalar op matrix' operations: Do implicit conversions
- public MLMatrix $plus(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "+", false);
- }
-
- public MLMatrix add(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "+", false);
- }
-
- public MLMatrix $minus(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "-", false);
- }
-
- public MLMatrix minus(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "-", false);
- }
-
- public MLMatrix $times(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "*", false);
- }
-
- public MLMatrix elementWiseMultiply(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "*", false);
- }
-
- public MLMatrix $div(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "/", false);
- }
-
- public MLMatrix divide(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "/", false);
- }
-
- public MLMatrix $greater(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, ">", false);
- }
-
- public MLMatrix $less(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "<", false);
- }
-
- public MLMatrix $greater$eq(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, ">=", false);
- }
-
- public MLMatrix $less$eq(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "<=", false);
- }
-
- public MLMatrix $eq$eq(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "==", false);
- }
-
- public MLMatrix $bang$eq(Double scalar) throws IOException, DMLException {
- return scalarBinaryOp(scalar, "!=", false);
- }
-
-}
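
A short Scala sketch of the operator surface above; precedence is enforced by the Scala compiler, so bracket explicitly ('mat' and 'other' are assumed MLMatrix instances):

    val mask    = mat > 0.5           // element-wise comparison with a scalar
    val scaled  = mat * 2.0           // element-wise scalar multiplication
    val shifted = (mat + 1.0) - other // bracket to enforce DML-like precedence
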
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/MLOutput.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLOutput.java b/src/main/java/org/apache/sysml/api/MLOutput.java
deleted file mode 100644
index a16eccd..0000000
--- a/src/main/java/org/apache/sysml/api/MLOutput.java
+++ /dev/null
@@ -1,267 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.sysml.api;
-
-import java.util.Map;
-
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.StructType;
-import org.apache.sysml.runtime.DMLRuntimeException;
-import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
-import org.apache.sysml.runtime.instructions.spark.functions.GetMLBlock;
-import org.apache.sysml.runtime.instructions.spark.utils.FrameRDDConverterUtils;
-import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
-import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
-import org.apache.sysml.runtime.matrix.data.CSVFileFormatProperties;
-import org.apache.sysml.runtime.matrix.data.FrameBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixBlock;
-import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
-
-/**
- * This is a simple container object that holds the output of an MLContext execute call
- *
- * @deprecated This will be removed in SystemML 1.0. Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}
- * and {@link org.apache.sysml.api.mlcontext.MLResults}
- */
-@Deprecated
-public class MLOutput {
-
- Map<String, JavaPairRDD<?,?>> _outputs;
- private Map<String, MatrixCharacteristics> _outMetadata = null;
-
- public MLOutput(Map<String, JavaPairRDD<?,?>> outputs, Map<String, MatrixCharacteristics> outMetadata) {
- this._outputs = outputs;
- this._outMetadata = outMetadata;
- }
-
- public MatrixBlock getMatrixBlock(String varName) throws DMLRuntimeException {
- MatrixCharacteristics mc = getMatrixCharacteristics(varName);
- // The matrix block is always pushed to an RDD and then collected.
- // We can later avoid this by returning the symbol table rather than "Map<String, JavaPairRDD<MatrixIndexes,MatrixBlock>> _outputs".
- return SparkExecutionContext.toMatrixBlock(getBinaryBlockedRDD(varName), (int) mc.getRows(), (int) mc.getCols(),
- mc.getRowsPerBlock(), mc.getColsPerBlock(), mc.getNonZeros());
- }
-
- @SuppressWarnings("unchecked")
- public JavaPairRDD<MatrixIndexes,MatrixBlock> getBinaryBlockedRDD(String varName) throws DMLRuntimeException {
- if(_outputs.containsKey(varName)) {
- return (JavaPairRDD<MatrixIndexes,MatrixBlock>) _outputs.get(varName);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the outputs.");
- }
-
- @SuppressWarnings("unchecked")
- public JavaPairRDD<Long,FrameBlock> getFrameBinaryBlockedRDD(String varName) throws DMLRuntimeException {
- if(_outputs.containsKey(varName)) {
- return (JavaPairRDD<Long,FrameBlock>)_outputs.get(varName);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the outputs.");
- }
-
- public MatrixCharacteristics getMatrixCharacteristics(String varName) throws DMLRuntimeException {
- if(_outputs.containsKey(varName)) {
- return _outMetadata.get(varName);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
- /**
- * Note that the output DataFrame has an additional ID column.
- * An easy way to obtain the DataFrame without the ID column is df.drop("__INDEX").
- *
- * @param sparkSession the Spark Session
- * @param varName the variable name
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SparkSession sparkSession, String varName) throws DMLRuntimeException {
- if(sparkSession == null) {
- throw new DMLRuntimeException("SparkSession is not created.");
- }
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = _outMetadata.get(varName);
- return RDDConverterUtils.binaryBlockToDataFrame(sparkSession, rdd, mc, false);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
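
A minimal sketch of the pattern described above ('sparkSession' and the "output" variable name are assumptions):

    val df = out.getDF(sparkSession, "output").drop("__INDEX")
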
- /**
- * Note that the output DataFrame has an additional ID column.
- * An easy way to obtain the DataFrame without the ID column is df.drop("__INDEX").
- *
- * @param sqlContext the SQL Context
- * @param varName the variable name
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SQLContext sqlContext, String varName) throws DMLRuntimeException {
- if (sqlContext == null) {
- throw new DMLRuntimeException("SQLContext is not created");
- }
- SparkSession sparkSession = sqlContext.sparkSession();
- return getDF(sparkSession, varName);
- }
-
- /**
- * Obtain the DataFrame
- *
- * @param sparkSession the Spark Session
- * @param varName the variable name
- * @param outputVector if true, returns a DataFrame with two columns: ID and org.apache.spark.ml.linalg.Vector
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SparkSession sparkSession, String varName, boolean outputVector) throws DMLRuntimeException {
- if(sparkSession == null) {
- throw new DMLRuntimeException("SparkSession is not created.");
- }
- if(outputVector) {
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = _outMetadata.get(varName);
- return RDDConverterUtils.binaryBlockToDataFrame(sparkSession, rdd, mc, true);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
- else {
- return getDF(sparkSession, varName);
- }
-
- }
-
- /**
- * Obtain the DataFrame
- *
- * @param sqlContext the SQL Context
- * @param varName the variable name
- * @param outputVector if true, returns a DataFrame with two columns: ID and org.apache.spark.ml.linalg.Vector
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SQLContext sqlContext, String varName, boolean outputVector) throws DMLRuntimeException {
- if (sqlContext == null) {
- throw new DMLRuntimeException("SQLContext is not created");
- }
- SparkSession sparkSession = sqlContext.sparkSession();
- return getDF(sparkSession, varName, outputVector);
- }
-
- /**
- * This method improves the performance of MLPipeline wrappers.
- *
- * @param sparkSession the Spark Session
- * @param varName the variable name
- * @param mc the matrix characteristics
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SparkSession sparkSession, String varName, MatrixCharacteristics mc)
- throws DMLRuntimeException
- {
- if(sparkSession == null) {
- throw new DMLRuntimeException("SparkSession is not created.");
- }
- JavaPairRDD<MatrixIndexes,MatrixBlock> binaryBlockRDD = getBinaryBlockedRDD(varName);
- return RDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryBlockRDD, mc, true);
- }
-
- /**
- * This method improves the performance of MLPipeline wrappers.
- *
- * @param sqlContext the SQL Context
- * @param varName the variable name
- * @param mc the matrix characteristics
- * @return the DataFrame
- * @throws DMLRuntimeException if DMLRuntimeException occurs
- */
- public Dataset<Row> getDF(SQLContext sqlContext, String varName, MatrixCharacteristics mc)
- throws DMLRuntimeException
- {
- if (sqlContext == null) {
- throw new DMLRuntimeException("SQLContext is not created");
- }
- SparkSession sparkSession = sqlContext.sparkSession();
- return getDF(sparkSession, varName, mc);
- }
-
- public JavaRDD<String> getStringRDD(String varName, String format) throws DMLRuntimeException {
- if(format.equals("text")) {
- JavaPairRDD<MatrixIndexes, MatrixBlock> binaryRDD = getBinaryBlockedRDD(varName);
- MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
- return RDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn);
- }
- else {
- throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
- }
- }
-
- public JavaRDD<String> getStringFrameRDD(String varName, String format, CSVFileFormatProperties fprop ) throws DMLRuntimeException {
- JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName);
- MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
- if(format.equals("csv")) {
- return FrameRDDConverterUtils.binaryBlockToCsv(binaryRDD, mcIn, fprop, false);
- }
- else if(format.equals("text")) {
- return FrameRDDConverterUtils.binaryBlockToTextCell(binaryRDD, mcIn);
- }
- else {
- throw new DMLRuntimeException("The output format:" + format + " is not implemented yet.");
- }
-
- }
-
- public Dataset<Row> getDataFrameRDD(String varName, JavaSparkContext jsc) throws DMLRuntimeException {
- JavaPairRDD<Long, FrameBlock> binaryRDD = getFrameBinaryBlockedRDD(varName);
- MatrixCharacteristics mcIn = getMatrixCharacteristics(varName);
- SparkSession sparkSession = SparkSession.builder().sparkContext(jsc.sc()).getOrCreate();
- return FrameRDDConverterUtils.binaryBlockToDataFrame(sparkSession, binaryRDD, mcIn, null);
- }
-
- public MLMatrix getMLMatrix(MLContext ml, SparkSession sparkSession, String varName) throws DMLRuntimeException {
- if(sparkSession == null) {
- throw new DMLRuntimeException("SparkSession is not created.");
- }
- else if(ml == null) {
- throw new DMLRuntimeException("MLContext is not created.");
- }
- JavaPairRDD<MatrixIndexes,MatrixBlock> rdd = getBinaryBlockedRDD(varName);
- if(rdd != null) {
- MatrixCharacteristics mc = getMatrixCharacteristics(varName);
- StructType schema = MLBlock.getDefaultSchemaForBinaryBlock();
- return new MLMatrix(sparkSession.createDataFrame(rdd.map(new GetMLBlock()).rdd(), schema), mc, ml);
- }
- throw new DMLRuntimeException("Variable " + varName + " not found in the output symbol table.");
- }
-
- public MLMatrix getMLMatrix(MLContext ml, SQLContext sqlContext, String varName) throws DMLRuntimeException {
- if (sqlContext == null) {
- throw new DMLRuntimeException("SQLContext is not created");
- }
- SparkSession sparkSession = sqlContext.sparkSession();
- return getMLMatrix(ml, sparkSession, varName);
- }
-}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/7ba17c7f/src/main/java/org/apache/sysml/api/python/SystemML.py
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/python/SystemML.py b/src/main/java/org/apache/sysml/api/python/SystemML.py
deleted file mode 100644
index b22c570..0000000
--- a/src/main/java/org/apache/sysml/api/python/SystemML.py
+++ /dev/null
@@ -1,232 +0,0 @@
-#!/usr/bin/python
-#-------------------------------------------------------------
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-#-------------------------------------------------------------
-
-from __future__ import division
-from py4j.protocol import Py4JJavaError, Py4JError
-import traceback
-import os
-from pyspark.context import SparkContext
-from pyspark.sql import DataFrame, SparkSession
-from pyspark.rdd import RDD
-
-
-class MLContext(object):
-
- """
- Simple wrapper class for MLContext in SystemML.jar
- ...
- Attributes
- ----------
- ml : MLContext
- A reference to the java MLContext
- sc : SparkContext
- The SparkContext that has been specified during initialization
- """
-
- def __init__(self, sc, *args):
- """
- If initialized with a SparkContext, will connect to the Java MLContext
- class.
- args:
- sc (SparkContext): the current SparkContext
- monitor (boolean=False): Whether to monitor the performance
- forceSpark (boolean=False): Whether to force execution on spark
- returns:
- MLContext: Instance of MLContext
- """
-
- try:
- monitorPerformance = (args[0] if len(args) > 0 else False)
- setForcedSparkExecType = (args[1] if len(args) > 1 else False)
- self.sc = sc
- self.ml = sc._jvm.org.apache.sysml.api.MLContext(sc._jsc, monitorPerformance, setForcedSparkExecType)
- self.sparkSession = SparkSession.builder.getOrCreate()
- except Py4JError:
- traceback.print_exc()
-
- def reset(self):
- """
-        Call this method if you want to clear any RDDs set via
-        registerInput or registerOutput.
- """
- try:
- self.ml.reset()
- except Py4JJavaError:
- traceback.print_exc()
-
- def execute(self, dmlScriptFilePath, *args):
- """
- Executes the script in spark-mode by passing the arguments to the
- MLContext java class.
- Returns:
- MLOutput: an instance of the MLOutput-class
- """
- numArgs = len(args) + 1
- try:
- if numArgs == 1:
- jmlOut = self.ml.execute(dmlScriptFilePath)
- mlOut = MLOutput(jmlOut, self.sc)
- return mlOut
- elif numArgs == 2:
- jmlOut = self.ml.execute(dmlScriptFilePath, args[0])
- mlOut = MLOutput(jmlOut, self.sc)
- return mlOut
- elif numArgs == 3:
- jmlOut = self.ml.execute(dmlScriptFilePath, args[0], args[1])
- mlOut = MLOutput(jmlOut, self.sc)
- return mlOut
- elif numArgs == 4:
- jmlOut = self.ml.execute(dmlScriptFilePath, args[0], args[1], args[2])
- mlOut = MLOutput(jmlOut, self.sc)
- return mlOut
- else:
- raise TypeError('Arguments do not match MLContext-API')
- except Py4JJavaError:
- traceback.print_exc()
-
- def executeScript(self, dmlScript, nargs=None, outputs=None, isPyDML=False, configFilePath=None):
- """
- Executes the script in spark-mode by passing the arguments to the
- MLContext java class.
- Returns:
- MLOutput: an instance of the MLOutput-class
- """
- try:
- # Register inputs as needed
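- # DataFrames are registered as matrix inputs; all other values are
- # passed to the script as string named arguments.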
- if nargs is not None:
- for key, value in list(nargs.items()):
- if isinstance(value, DataFrame):
- self.registerInput(key, value)
- del nargs[key]
- else:
- nargs[key] = str(value)
- else:
- nargs = {}
-
- # Register outputs as needed
- if outputs is not None:
- for out in outputs:
- self.registerOutput(out)
-
- # Execute script
- jml_out = self.ml.executeScript(dmlScript, nargs, isPyDML, configFilePath)
- ml_out = MLOutput(jml_out, self.sc)
- return ml_out
- except Py4JJavaError:
- traceback.print_exc()
-
- def registerInput(self, varName, src, *args):
- """
- Method to register inputs used by the DML script.
- Supported format:
- 1. DataFrame
- 2. CSV/Text (as JavaRDD<String> or JavaPairRDD<LongWritable, Text>)
- 3. Binary blocked RDD (JavaPairRDD<MatrixIndexes,MatrixBlock>))
- Also overloaded to support metadata information such as format, rlen, clen, ...
- Please note the variable names given below in quotes correspond to the variables in DML script.
- These variables need to have corresponding read/write associated in DML script.
- Currently, only matrix variables are supported through registerInput/registerOutput interface.
- To pass scalar variables, use named/positional arguments (described later) or wrap them into matrix variable.
- """
- numArgs = len(args) + 2
-
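- # Unwrap the underlying Java object: a DataFrame exposes _jdf, an
- # RDD exposes _jrdd; anything else is passed through unchanged.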
- if hasattr(src, '_jdf'):
- rdd = src._jdf
- elif hasattr(src, '_jrdd'):
- rdd = src._jrdd
- else:
- rdd = src
-
- try:
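- # Only these argument counts have matching Java overloads; any
- # other arity raises a TypeError below.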
- if numArgs == 2:
- self.ml.registerInput(varName, rdd)
- elif numArgs == 3:
- self.ml.registerInput(varName, rdd, args[0])
- elif numArgs == 4:
- self.ml.registerInput(varName, rdd, args[0], args[1])
- elif numArgs == 5:
- self.ml.registerInput(varName, rdd, args[0], args[1], args[2])
- elif numArgs == 6:
- self.ml.registerInput(varName, rdd, args[0], args[1], args[2], args[3])
- elif numArgs == 7:
- self.ml.registerInput(varName, rdd, args[0], args[1], args[2], args[3], args[4])
- elif numArgs == 10:
- self.ml.registerInput(varName, rdd, args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7])
- else:
- raise TypeError('Arguments do not match MLContext-API')
- except Py4JJavaError:
- traceback.print_exc()
-
- def registerOutput(self, varName):
- """
- Register output variables used in the DML script
- args:
- varName: (String) The name used in the DML script
- """
-
- try:
- self.ml.registerOutput(varName)
- except Py4JJavaError:
- traceback.print_exc()
-
- def getDmlJson(self):
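- """Returns runtime monitoring information in JSON format."""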
- try:
- return self.ml.getMonitoringUtil().getRuntimeInfoInJSONFormat()
- except Py4JJavaError:
- traceback.print_exc()
-
-
-class MLOutput(object):
-
- """
- This is a simple wrapper object that returns the output of execute from MLContext
- ...
- Attributes
- ----------
- jmlOut MLContext:
- A reference to the MLOutput object through py4j
- """
-
- def __init__(self, jmlOut, sc):
- self.jmlOut = jmlOut
- self.sc = sc
-
- def getBinaryBlockedRDD(self, varName):
- raise Exception('Not supported in Python MLContext')
-
- def getMatrixCharacteristics(self, varName):
- raise Exception('Not supported in Python MLContext')
-
- def getDF(self, sparkSession, varName):
- try:
- jdf = self.jmlOut.getDF(sparkSession, varName)
- df = DataFrame(jdf, sparkSession)
- return df
- except Py4JJavaError:
- traceback.print_exc()
-
- def getMLMatrix(self, sparkSession, varName):
- raise Exception('Not supported in Python MLContext')
-
- def getStringRDD(self, varName, format):
- raise Exception('Not supported in Python MLContext')
-