You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/02 18:25:17 UTC
incubator-systemml git commit: [HOTFIX] Allow multiple MLContexts to
set the configuration property
Repository: incubator-systemml
Updated Branches:
refs/heads/master 7989ab4f3 -> 8324b69f1
[HOTFIX] Allow multiple MLContexts to set the configuration property
- Also, added a bugfix in mllearn to enable the force-GPU option.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/8324b69f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/8324b69f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/8324b69f
Branch: refs/heads/master
Commit: 8324b69f11fb71890e0b592603e759c68f4db87f
Parents: 7989ab4
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue May 2 10:25:01 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue May 2 11:25:01 2017 -0700
----------------------------------------------------------------------
.../sysml/api/mlcontext/ScriptExecutor.java | 35 +++++++++++++++-----
.../sysml/api/ml/BaseSystemMLClassifier.scala | 4 ++-
2 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index ee710b6..56beef3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -248,8 +248,30 @@ public class ScriptExecutor {
if (symbolTable != null) {
executionContext.setVariables(symbolTable);
}
- oldStatistics = DMLScript.STATISTICS;
- DMLScript.STATISTICS = statistics;
+
+ }
+
+ /**
+ * Set the global flags (for example: statistics, gpu, etc).
+ */
+ protected void setGlobalFlags() {
+ oldStatistics = DMLScript.STATISTICS;
+ DMLScript.STATISTICS = statistics;
+ oldForceGPU = DMLScript.FORCE_ACCELERATOR;
+ DMLScript.FORCE_ACCELERATOR = forceGPU;
+ oldGPU = DMLScript.USE_ACCELERATOR;
+ DMLScript.USE_ACCELERATOR = gpu;
+ DMLScript.STATISTICS_COUNT = statisticsMaxHeavyHitters;
+ }
+
+ /**
+ * Reset the global flags (for example: statistics, gpu, etc) post-execution.
+ */
+ protected void resetGlobalFlags() {
+ DMLScript.STATISTICS = oldStatistics;
+ DMLScript.FORCE_ACCELERATOR = oldForceGPU;
+ DMLScript.USE_ACCELERATOR = oldGPU;
+ DMLScript.STATISTICS_COUNT = 10;
}
/**
@@ -327,6 +349,7 @@ public class ScriptExecutor {
script.setScriptExecutor(this);
// Set global variable indicating the script type
DMLScript.SCRIPT_TYPE = script.getScriptType();
+ setGlobalFlags();
}
/**
@@ -334,9 +357,7 @@ public class ScriptExecutor {
*/
protected void cleanupAfterExecution() {
restoreInputsInSymbolTable();
- DMLScript.USE_ACCELERATOR = oldGPU;
- DMLScript.FORCE_ACCELERATOR = oldForceGPU;
- DMLScript.STATISTICS = oldStatistics;
+ resetGlobalFlags();
}
/**
@@ -652,8 +673,6 @@ public class ScriptExecutor {
*/
public void setGPU(boolean enabled) {
this.gpu = enabled;
- oldGPU = DMLScript.USE_ACCELERATOR;
- DMLScript.USE_ACCELERATOR = gpu;
}
/**
@@ -663,8 +682,6 @@ public class ScriptExecutor {
*/
public void setForceGPU(boolean enabled) {
this.forceGPU = enabled;
- oldForceGPU = DMLScript.FORCE_ACCELERATOR;
- DMLScript.FORCE_ACCELERATOR = forceGPU;
}
}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index 2dbcc03..f0af799 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -81,7 +81,9 @@ trait BaseSystemMLEstimatorOrModel {
def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1:Int):BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this}
def setConfigProperty(key:String, value:String):BaseSystemMLEstimatorOrModel = { config.put(key, value); this}
def updateML(ml:MLContext):Unit = {
- ml.setGPU(enableGPU); ml.setExplain(explain); ml.setStatistics(statistics); config.map(x => ml.setConfigProperty(x._1, x._2))
+ ml.setGPU(enableGPU); ml.setForceGPU(forceGPU);
+ ml.setExplain(explain); ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters);
+ config.map(x => ml.setConfigProperty(x._1, x._2))
}
def copyProperties(other:BaseSystemMLEstimatorOrModel):BaseSystemMLEstimatorOrModel = {
other.setGPU(enableGPU); other.setForceGPU(forceGPU);