Repository: incubator-systemml
Updated Branches:
  refs/heads/master 7989ab4f3 -> 8324b69f1


[HOTFIX] Allow multiple MLContext instances to set configuration properties without interfering with each other

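- Global DMLScript flags (statistics, GPU, force-GPU, statistics heavy-hitter count) are now applied just before script execution and reset afterwards, instead of being written immediately by the setters.
- Also fixed a bug in mllearn where the force-GPU option (and the statistics heavy-hitter count) was not passed on to MLContext.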


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/8324b69f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/8324b69f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/8324b69f

Branch: refs/heads/master
Commit: 8324b69f11fb71890e0b592603e759c68f4db87f
Parents: 7989ab4
Author: Niketan Pansare <[email protected]>
Authored: Tue May 2 10:25:01 2017 -0800
Committer: Niketan Pansare <[email protected]>
Committed: Tue May 2 11:25:01 2017 -0700

----------------------------------------------------------------------
 .../sysml/api/mlcontext/ScriptExecutor.java     | 35 +++++++++++++++-----
 .../sysml/api/ml/BaseSystemMLClassifier.scala   |  4 ++-
 2 files changed, 29 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
index ee710b6..56beef3 100644
--- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
+++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java
@@ -248,8 +248,30 @@ public class ScriptExecutor {
                if (symbolTable != null) {
                        executionContext.setVariables(symbolTable);
                }
-        oldStatistics = DMLScript.STATISTICS;
-        DMLScript.STATISTICS = statistics;
+    
+       }
+       
+       /**
+        * Set the global flags (for example: statistics, gpu, etc).
+        */
+       protected void setGlobalFlags() {
+               oldStatistics = DMLScript.STATISTICS;
+    DMLScript.STATISTICS = statistics;
+    oldForceGPU = DMLScript.FORCE_ACCELERATOR;
+    DMLScript.FORCE_ACCELERATOR = forceGPU;
+    oldGPU = DMLScript.USE_ACCELERATOR;
+    DMLScript.USE_ACCELERATOR = gpu;
+    DMLScript.STATISTICS_COUNT = statisticsMaxHeavyHitters;
+       }
+       
+       /**
+        * Reset the global flags (for example: statistics, gpu, etc) post-execution. 
+        */
+       protected void resetGlobalFlags() {
+               DMLScript.STATISTICS = oldStatistics;
+               DMLScript.FORCE_ACCELERATOR = oldForceGPU;
+               DMLScript.USE_ACCELERATOR = oldGPU;
+               DMLScript.STATISTICS_COUNT = 10;
        }
 
        /**
@@ -327,6 +349,7 @@ public class ScriptExecutor {
                script.setScriptExecutor(this);
                // Set global variable indicating the script type
                DMLScript.SCRIPT_TYPE = script.getScriptType();
+               setGlobalFlags();
        }
 
        /**
@@ -334,9 +357,7 @@ public class ScriptExecutor {
         */
        protected void cleanupAfterExecution() {
                restoreInputsInSymbolTable();
-               DMLScript.USE_ACCELERATOR = oldGPU;
-               DMLScript.FORCE_ACCELERATOR = oldForceGPU;
-               DMLScript.STATISTICS = oldStatistics;
+               resetGlobalFlags();
        }
 
        /**
@@ -652,8 +673,6 @@ public class ScriptExecutor {
         */
     public void setGPU(boolean enabled) {
         this.gpu = enabled;
-        oldGPU = DMLScript.USE_ACCELERATOR;
-        DMLScript.USE_ACCELERATOR = gpu;
     }
        
        /**
@@ -663,8 +682,6 @@ public class ScriptExecutor {
         */
     public void setForceGPU(boolean enabled) {
         this.forceGPU = enabled;
-        oldForceGPU = DMLScript.FORCE_ACCELERATOR;
-        DMLScript.FORCE_ACCELERATOR = forceGPU;
     }
 
 }

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
index 2dbcc03..f0af799 100644
--- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
+++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala
@@ -81,7 +81,9 @@ trait BaseSystemMLEstimatorOrModel {
  def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1:Int):BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this}
  def setConfigProperty(key:String, value:String):BaseSystemMLEstimatorOrModel = { config.put(key, value); this}
   def updateML(ml:MLContext):Unit = {
-    ml.setGPU(enableGPU); ml.setExplain(explain); ml.setStatistics(statistics); config.map(x => ml.setConfigProperty(x._1, x._2))
+    ml.setGPU(enableGPU); ml.setForceGPU(forceGPU);
+    ml.setExplain(explain); ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters); 
+    config.map(x => ml.setConfigProperty(x._1, x._2))
   }
  def copyProperties(other:BaseSystemMLEstimatorOrModel):BaseSystemMLEstimatorOrModel = {
     other.setGPU(enableGPU); other.setForceGPU(forceGPU);

Reply via email to