You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@systemml.apache.org by ni...@apache.org on 2016/01/13 01:41:03 UTC

incubator-systemml git commit: Resolving SYSTEMML-331 bug by adding a setConfig method to MLContext.

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 6b4205e29 -> 231f69dd0


Resolving SYSTEMML-331 bug by adding a setConfig method to MLContext.

One can use the following code to change the scratch space directory (e.g., on Databricks Cloud):
val ml = new MLContext(sc)
...
ml.setConfig("scratch", "some_other_scratch_dir")
...

Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/231f69dd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/231f69dd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/231f69dd

Branch: refs/heads/master
Commit: 231f69dd01021280e89b122782457d2fd7ee3366
Parents: 6b4205e
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Jan 12 16:41:01 2016 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Jan 12 16:41:01 2016 -0800

----------------------------------------------------------------------
 .../java/org/apache/sysml/api/MLContext.java    | 18 +++++++++++++--
 .../java/org/apache/sysml/conf/DMLConfig.java   | 24 +++++++++++++++++++-
 .../functions/mlcontext/GNMFTest.java           | 11 +++++++--
 3 files changed, 48 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index 4b9ad8d..6700f13 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -33,7 +33,6 @@ import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.rdd.RDD;
-
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.jmlc.JMLCUtils;
 import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
@@ -81,7 +80,6 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.utils.Explain;
 import org.apache.sysml.utils.Statistics;
 import org.apache.sysml.utils.Explain.ExplainCounts;
-
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.SQLContext;
 
@@ -195,6 +193,8 @@ public class MLContext {
 	private LocalVariableMap _variables = null; // temporary symbol table
 	private Program _rtprog = null;
 	
+	private HashMap<String, String> _additionalConfigs = new HashMap<String, String>();
+	
 	// --------------------------------------------------
 	// _monitorUtils is set only when MLContext(sc, true)
 	private SparkMonitoringUtil _monitorUtils = null;
@@ -226,6 +226,15 @@ public class MLContext {
 		initializeSpark(sc.sc(), false, false);
 	}
 	
+	/**
+	 * Sets a named configuration parameter that is applied on top of the loaded DMLConfig when a script runs.
+	 * @param paramName configuration parameter name, e.g. {@code "scratch"}
+	 * @param paramVal value to assign; replaces any value previously set for the same name
+	 */
+	public void setConfig(String paramName, String paramVal) {
+		_additionalConfigs.put(paramName, paramVal);
+	}
+	
 	// ====================================================================================
 	// Register input APIs
 	// 1. DataFrame
@@ -904,6 +913,7 @@ public class MLContext {
 		_inVarnames = null;
 		_outVarnames = null;
 		_variables = null;
+		_additionalConfigs.clear();
 	}
 	
 	/**
@@ -1262,6 +1272,10 @@ public class MLContext {
 			config = new DMLConfig(configFilePath);
 		}
 		
+		for(Entry<String, String> param : _additionalConfigs.entrySet()) {
+			config.setTextValue(param.getKey(), param.getValue());
+		}
+		
 		ConfigurationManager.setConfig(config);
 		
 		String dmlScriptStr = null;

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 2782487..a7312d0 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -42,7 +42,6 @@ import org.w3c.dom.Document;
 import org.w3c.dom.Element;
 import org.w3c.dom.NodeList;
 import org.xml.sax.SAXException;
-
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.util.LocalFileUtils;
@@ -280,6 +279,29 @@ public class DMLConfig
 			elem.getFirstChild().setNodeValue(newTextValue);	
 		}
 	}
+	
+	/**
+	 * Sets a configuration parameter, creating a new {@code <root>} element if no XML root exists yet.
+	 * @param paramName parameter name, used as the XML element name (e.g. {@code scratch})
+	 * @param paramValue new value; stored as a DOM text node, so XML metacharacters are kept verbatim
+	 * @throws DMLRuntimeException if a new config document cannot be created (e.g. invalid element name)
+	 */
+	public void setTextValue(String paramName, String paramValue) throws DMLRuntimeException {
+		if(this.xml_root != null)
+			DMLConfig.setTextValue(this.xml_root, paramName, paramValue);
+		else {
+			try {
+				// build the DOM directly: parsing a concatenated string broke on values containing '&' or '<'
+				Document domTree = DocumentBuilderFactory.newInstance().newDocumentBuilder().newDocument();
+				Element param = domTree.createElement(paramName);
+				param.appendChild(domTree.createTextNode(paramValue));
+				domTree.appendChild(domTree.createElement("root")).appendChild(param);
+				this.xml_root = domTree.getDocumentElement();
+			} catch (Exception e) {
+				throw new DMLRuntimeException("Unable to set config value", e);
+			}
+		}
+	}
 
 	/**
 	 * 

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index a1aa149..ac725a6 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -30,16 +30,17 @@ import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
+import org.junit.Assert;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 import org.junit.runners.Parameterized.Parameters;
-
 import org.apache.sysml.api.DMLException;
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
 import org.apache.sysml.api.MLContext;
 import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.parser.ParseException;
 import org.apache.sysml.runtime.DMLRuntimeException;
 import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -50,7 +51,6 @@ import org.apache.sysml.runtime.util.MapReduceTool;
 import org.apache.sysml.test.integration.AutomatedTestBase;
 import org.apache.sysml.test.utils.TestUtils;
 import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-
 import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
 import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
 
@@ -172,10 +172,17 @@ public class GNMFTest extends AutomatedTestBase
 			
 			if(numRegisteredOutputs >= 2) {
 				mlCtx.registerOutput("W");
+				mlCtx.setConfig("cp.parallel.matrixmult", "false");
 			}
 			
 			MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs);
 			
+			if(numRegisteredOutputs >= 2) {
+				String configStr = ConfigurationManager.getConfig().getConfigInfo();
+				if(configStr.contains("cp.parallel.matrixmult: true"))
+					Assert.fail("Configuration not updated via setConfig");
+			}
+			
 			if(numRegisteredOutputs >= 1) {
 				JavaRDD<String> hOut = out.getStringRDD("H", "text");
 				String fName = output("h");