Repository: incubator-systemml Updated Branches: refs/heads/master 7989ab4f3 -> 8324b69f1
[HOTFIX] Allows multiple MLContext to set the configuration property - Also, added bugfix in mllearn to enable force GPU option. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/8324b69f Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/8324b69f Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/8324b69f Branch: refs/heads/master Commit: 8324b69f11fb71890e0b592603e759c68f4db87f Parents: 7989ab4 Author: Niketan Pansare <[email protected]> Authored: Tue May 2 10:25:01 2017 -0800 Committer: Niketan Pansare <[email protected]> Committed: Tue May 2 11:25:01 2017 -0700 ---------------------------------------------------------------------- .../sysml/api/mlcontext/ScriptExecutor.java | 35 +++++++++++++++----- .../sysml/api/ml/BaseSystemMLClassifier.scala | 4 ++- 2 files changed, 29 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java index ee710b6..56beef3 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -248,8 +248,30 @@ public class ScriptExecutor { if (symbolTable != null) { executionContext.setVariables(symbolTable); } - oldStatistics = DMLScript.STATISTICS; - DMLScript.STATISTICS = statistics; + + } + + /** + * Set the global flags (for example: statistics, gpu, etc). + */ + protected void setGlobalFlags() { + oldStatistics = DMLScript.STATISTICS; + DMLScript.STATISTICS = statistics; + oldForceGPU = DMLScript.FORCE_ACCELERATOR; + DMLScript.FORCE_ACCELERATOR = forceGPU; + oldGPU = DMLScript.USE_ACCELERATOR; + DMLScript.USE_ACCELERATOR = gpu; + DMLScript.STATISTICS_COUNT = statisticsMaxHeavyHitters; + } + + /** + * Reset the global flags (for example: statistics, gpu, etc) post-execution. + */ + protected void resetGlobalFlags() { + DMLScript.STATISTICS = oldStatistics; + DMLScript.FORCE_ACCELERATOR = oldForceGPU; + DMLScript.USE_ACCELERATOR = oldGPU; + DMLScript.STATISTICS_COUNT = 10; } /** @@ -327,6 +349,7 @@ public class ScriptExecutor { script.setScriptExecutor(this); // Set global variable indicating the script type DMLScript.SCRIPT_TYPE = script.getScriptType(); + setGlobalFlags(); } /** @@ -334,9 +357,7 @@ public class ScriptExecutor { */ protected void cleanupAfterExecution() { restoreInputsInSymbolTable(); - DMLScript.USE_ACCELERATOR = oldGPU; - DMLScript.FORCE_ACCELERATOR = oldForceGPU; - DMLScript.STATISTICS = oldStatistics; + resetGlobalFlags(); } /** @@ -652,8 +673,6 @@ public class ScriptExecutor { */ public void setGPU(boolean enabled) { this.gpu = enabled; - oldGPU = DMLScript.USE_ACCELERATOR; - DMLScript.USE_ACCELERATOR = gpu; } /** @@ -663,8 +682,6 @@ public class ScriptExecutor { */ public void setForceGPU(boolean enabled) { this.forceGPU = enabled; - oldForceGPU = DMLScript.FORCE_ACCELERATOR; - DMLScript.FORCE_ACCELERATOR = forceGPU; } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/8324b69f/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala ---------------------------------------------------------------------- diff --git a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala index 2dbcc03..f0af799 100644 --- a/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala +++ b/src/main/scala/org/apache/sysml/api/ml/BaseSystemMLClassifier.scala @@ -81,7 +81,9 @@ trait BaseSystemMLEstimatorOrModel { def setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters1:Int):BaseSystemMLEstimatorOrModel = { statisticsMaxHeavyHitters = statisticsMaxHeavyHitters1; this} def setConfigProperty(key:String, value:String):BaseSystemMLEstimatorOrModel = { config.put(key, value); this} def updateML(ml:MLContext):Unit = { - ml.setGPU(enableGPU); ml.setExplain(explain); ml.setStatistics(statistics); config.map(x => ml.setConfigProperty(x._1, x._2)) + ml.setGPU(enableGPU); ml.setForceGPU(forceGPU); + ml.setExplain(explain); ml.setStatistics(statistics); ml.setStatisticsMaxHeavyHitters(statisticsMaxHeavyHitters); + config.map(x => ml.setConfigProperty(x._1, x._2)) } def copyProperties(other:BaseSystemMLEstimatorOrModel):BaseSystemMLEstimatorOrModel = { other.setGPU(enableGPU); other.setForceGPU(forceGPU);
