You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2016/01/13 01:41:03 UTC
incubator-systemml git commit: Resolving SYSTEMML-331 bug by adding
setConfig method in MLContext.
Repository: incubator-systemml
Updated Branches:
refs/heads/master 6b4205e29 -> 231f69dd0
Resolve the SYSTEMML-331 bug by adding a setConfig method to MLContext.
The following code changes the scratch space directory (configuration parameter "scratch"), e.g. when running on Databricks Cloud:
val ml = new MLContext(sc)
...
ml.setConfig("scratch", "some_other_scratch_dir")
...
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/231f69dd
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/231f69dd
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/231f69dd
Branch: refs/heads/master
Commit: 231f69dd01021280e89b122782457d2fd7ee3366
Parents: 6b4205e
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue Jan 12 16:41:01 2016 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue Jan 12 16:41:01 2016 -0800
----------------------------------------------------------------------
.../java/org/apache/sysml/api/MLContext.java | 18 +++++++++++++--
.../java/org/apache/sysml/conf/DMLConfig.java | 24 +++++++++++++++++++-
.../functions/mlcontext/GNMFTest.java | 11 +++++++--
3 files changed, 48 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/main/java/org/apache/sysml/api/MLContext.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java
index 4b9ad8d..6700f13 100644
--- a/src/main/java/org/apache/sysml/api/MLContext.java
+++ b/src/main/java/org/apache/sysml/api/MLContext.java
@@ -33,7 +33,6 @@ import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
-
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.api.jmlc.JMLCUtils;
import org.apache.sysml.api.monitoring.SparkMonitoringUtil;
@@ -81,7 +80,6 @@ import org.apache.sysml.runtime.matrix.data.OutputInfo;
import org.apache.sysml.utils.Explain;
import org.apache.sysml.utils.Statistics;
import org.apache.sysml.utils.Explain.ExplainCounts;
-
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
@@ -195,6 +193,8 @@ public class MLContext {
private LocalVariableMap _variables = null; // temporary symbol table
private Program _rtprog = null;
+ private HashMap<String, String> _additionalConfigs = new HashMap<String, String>();
+
// --------------------------------------------------
// _monitorUtils is set only when MLContext(sc, true)
private SparkMonitoringUtil _monitorUtils = null;
@@ -226,6 +226,15 @@ public class MLContext {
initializeSpark(sc.sc(), false, false);
}
+ /**
+  * Registers a named configuration value that overrides the corresponding
+  * entry of the DML configuration for scripts executed through this context.
+  * <p>
+  * The overrides are applied to the loaded {@code DMLConfig} (through
+  * {@code DMLConfig.setTextValue}) when a script is executed. Calling this
+  * method again with the same name replaces the previously registered value.
+  * Example: {@code ml.setConfig("scratch", "some_other_scratch_dir")}.
+  *
+  * @param paramName name of the configuration parameter, e.g. {@code "scratch"}; must not be null
+  * @param paramVal value to use for that parameter; must not be null
+  * @throws IllegalArgumentException if {@code paramName} or {@code paramVal} is null
+  */
+ public void setConfig(String paramName, String paramVal) {
+     // Validate eagerly so a bad argument is reported at the call site instead of
+     // only later, when the stored overrides are applied during execution.
+     if (paramName == null || paramVal == null) {
+         throw new IllegalArgumentException("setConfig requires a non-null paramName and paramVal");
+     }
+     _additionalConfigs.put(paramName, paramVal);
+ }
+
// ====================================================================================
// Register input APIs
// 1. DataFrame
@@ -904,6 +913,7 @@ public class MLContext {
_inVarnames = null;
_outVarnames = null;
_variables = null;
+ _additionalConfigs.clear();
}
/**
@@ -1262,6 +1272,10 @@ public class MLContext {
config = new DMLConfig(configFilePath);
}
+ for(Entry<String, String> param : _additionalConfigs.entrySet()) {
+ config.setTextValue(param.getKey(), param.getValue());
+ }
+
ConfigurationManager.setConfig(config);
String dmlScriptStr = null;
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/main/java/org/apache/sysml/conf/DMLConfig.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/conf/DMLConfig.java b/src/main/java/org/apache/sysml/conf/DMLConfig.java
index 2782487..a7312d0 100644
--- a/src/main/java/org/apache/sysml/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysml/conf/DMLConfig.java
@@ -42,7 +42,6 @@ import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;
-
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.util.LocalFileUtils;
@@ -280,6 +279,29 @@ public class DMLConfig
elem.getFirstChild().setNodeValue(newTextValue);
}
}
+
+ /**
+ * Method to update the key value
+ * @param paramName
+ * @param paramValue
+ */
+ public void setTextValue(String paramName, String paramValue) throws DMLRuntimeException {
+ if(this.xml_root != null)
+ DMLConfig.setTextValue(this.xml_root, paramName, paramValue);
+ else {
+ DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
+ factory.setIgnoringComments(true); //ignore XML comments
+ DocumentBuilder builder;
+ try {
+ builder = factory.newDocumentBuilder();
+ String configString = "<root><" + paramName + ">"+paramValue+"</" + paramName + "></root>";
+ Document domTree = builder.parse(new ByteArrayInputStream(configString.getBytes("UTF-8")));
+ this.xml_root = domTree.getDocumentElement();
+ } catch (Exception e) {
+ throw new DMLRuntimeException("Unable to set config value", e);
+ }
+ }
+ }
/**
*
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/231f69dd/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
index a1aa149..ac725a6 100644
--- a/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
+++ b/src/test/java/org/apache/sysml/test/integration/functions/mlcontext/GNMFTest.java
@@ -30,16 +30,17 @@ import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
+import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
-
import org.apache.sysml.api.DMLException;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM;
import org.apache.sysml.api.MLContext;
import org.apache.sysml.api.MLOutput;
+import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.parser.ParseException;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
@@ -50,7 +51,6 @@ import org.apache.sysml.runtime.util.MapReduceTool;
import org.apache.sysml.test.integration.AutomatedTestBase;
import org.apache.sysml.test.utils.TestUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
-
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
@@ -172,10 +172,17 @@ public class GNMFTest extends AutomatedTestBase
if(numRegisteredOutputs >= 2) {
mlCtx.registerOutput("W");
+ mlCtx.setConfig("cp.parallel.matrixmult", "false");
}
MLOutput out = mlCtx.execute(fullDMLScriptName, programArgs);
+ if(numRegisteredOutputs >= 2) {
+ String configStr = ConfigurationManager.getConfig().getConfigInfo();
+ if(configStr.contains("cp.parallel.matrixmult: true"))
+ Assert.fail("Configuration not updated via setConfig");
+ }
+
if(numRegisteredOutputs >= 1) {
JavaRDD<String> hOut = out.getStringRDD("H", "text");
String fName = output("h");