You are viewing a plain-text version of this content; the hyperlink to its canonical location was not preserved in this copy.
Posted to commits@systemml.apache.org by ni...@apache.org on 2017/05/02 18:25:17 UTC

incubator-systemml git commit: [HOTFIX] Allow multiple MLContext instances to set global configuration flags

Repository: incubator-systemml
Updated Branches:
  refs/heads/master 7989ab4f3 -> 8324b69f1


[HOTFIX] Allow multiple MLContext instances to set global configuration flags

- Also fixed a bug in mllearn so that the force-GPU option (and the statistics max-heavy-hitters setting) is passed through to MLContext.


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/8324b69f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/8324b69f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/8324b69f

Branch: refs/heads/master
Commit: 8324b69f11fb71890e0b592603e759c68f4db87f
Parents: 7989ab4
Author: Niketan Pansare <np...@us.ibm.com>
Authored: Tue May 2 10:25:01 2017 -0800
Committer: Niketan Pansare <np...@us.ibm.com>
Committed: Tue May 2 11:25:01 2017 -0700

----------------------------------------------------------------------
 .../sysml/api/mlcontext/ScriptExecutor.java     | 35 +++++++++++++++-----
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  4 ++-
 2 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index ee710b6..56beef3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -248,8 +248,30 @@ public class ScriptExecutor {
 		if (symbolTable != null) {
 			executionContext.setVariables(symbolTable);
 		}
-        oldStatistics = DMLScript.STATISTICS;
-        DMLScript.STATISTICS = statistics;
+    
+	}
+	
+	/**
+	 * Set the global flags (for example: statistics, gpu, etc).
+	 */
+	protected void setGlobalFlags() {
+		oldStatistics = DMLScript.STATISTICS;
+    DMLScript.STATISTICS = statistics;
+    oldForceGPU = DMLScript.FORCE_ACCELERATOR;
+    DMLScript.FORCE_ACCELERATOR = forceGPU;
+    oldGPU = DMLScript.USE_ACCELERATOR;
+    DMLScript.USE_ACCELERATOR = gpu;
+    DMLScript.STATISTICS_COUNT = statisticsMaxHeavyHitters;
+	}
+	
+	/**
+	 * Reset the global flags (for example: statistics, gpu, etc) post-execution. 
+	 */
+	protected void resetGlobalFlags() {
+		DMLScript.STATISTICS = oldStatistics;
+		DMLScript.FORCE_ACCELERATOR = oldForceGPU;
+		DMLScript.USE_ACCELERATOR = oldGPU;
+		DMLScript.STATISTICS_COUNT = 10;
 	}
 
 	/**
@@ -327,6 +349,7 @@ public class ScriptExecutor {
 		script.setScriptExecutor(this);
 		// Set global variable indicating the script type
 		DMLScript.SCRIPT_TYPE = script.getScriptType();
+		setGlobalFlags();
 	}
 
 	/**
@@ -334,9 +357,7 @@ public class ScriptExecutor {
 	 */
 	protected void cleanupAfterExecution() {
 		restoreInputsInSymbolTable();
-		DMLScript.USE_ACCELERATOR = oldGPU;
-		DMLScript.FORCE_ACCELERATOR = oldForceGPU;
-		DMLScript.STATISTICS = oldStatistics;
+		resetGlobalFlags();
 	}
 
 	/**
@@ -652,8 +673,6 @@ public class ScriptExecutor {
 	 */
     public void setGPU(boolean enabled) {
         this.gpu = enabled;
-        oldGPU = DMLScript.USE_ACCELERATOR;
-        DMLScript.USE_ACCELERATOR = gpu;
     }
 	
 	/**
@@ -663,8 +682,6 @@ public class ScriptExecutor {
 	 */
     public void setForceGPU(boolean enabled) {
         this.forceGPU = enabled;
-        oldForceGPU = DMLScript.FORCE_ACCELERATOR;
-        DMLScript.FORCE_ACCELERATOR = forceGPU;
     }
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index 2dbcc03..f0af799 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -81,7 +81,9 @@ trait BaseSystemMLEstimatorOrModel {
   def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1:Int):BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this}
   def setConfigProperty(key:String, value:String):BaseSystemMLEstimatorOrModel = { config.put(key, value); this}
   def updateML(ml:MLContext):Unit = {
-    ml.setGPU(enableGPU); ml.setExplain(explain); ml.setStatistics(statistics); config.map(x => ml.setConfigProperty(x._1, x._2))
+    ml.setGPU(enableGPU); ml.setForceGPU(forceGPU);
+    ml.setExplain(explain); ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters); 
+    config.map(x => ml.setConfigProperty(x._1, x._2))
   }
   def copyProperties(other:BaseSystemMLEstimatorOrModel):BaseSystemMLEstimatorOrModel = {
     other.setGPU(enableGPU); other.setForceGPU(forceGPU);