http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/Script.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/Script.java b/src/main/java/org/apache/sysml/api/mlcontext/Script.java new file mode 100644 index 0000000..65d3338 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/Script.java @@ -0,0 +1,652 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.cp.Data; + +import scala.Tuple2; +import scala.Tuple3; +import scala.collection.JavaConversions; + +/** + * A Script object encapsulates a DML or PYDML script. 
+ * + */ +public class Script { + + /** + * The type of script ({@code ScriptType.DML} or {@code ScriptType.PYDML}). + */ + private ScriptType scriptType; + /** + * The script content. + */ + private String scriptString; + /** + * The optional name of the script. + */ + private String name; + /** + * All inputs (input parameters ($) and input variables). + */ + private Map<String, Object> inputs = new LinkedHashMap<String, Object>(); + /** + * The input parameters ($). + */ + private Map<String, Object> inputParameters = new LinkedHashMap<String, Object>(); + /** + * The input variables. + */ + private Set<String> inputVariables = new LinkedHashSet<String>(); + /** + * The input matrix metadata if present. + */ + private Map<String, MatrixMetadata> inputMatrixMetadata = new LinkedHashMap<String, MatrixMetadata>(); + /** + * The output variables. + */ + private Set<String> outputVariables = new LinkedHashSet<String>(); + /** + * The symbol table containing the data associated with variables. + */ + private LocalVariableMap symbolTable = new LocalVariableMap(); + /** + * The ScriptExecutor which is used to define the execution of the script. + */ + private ScriptExecutor scriptExecutor; + /** + * The results of the execution of the script. + */ + private MLResults results; + + /** + * Script constructor, which by default creates a DML script. + */ + public Script() { + scriptType = ScriptType.DML; + } + + /** + * Script constructor, specifying the type of script ({@code ScriptType.DML} + * or {@code ScriptType.PYDML}). + * + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + */ + public Script(ScriptType scriptType) { + this.scriptType = scriptType; + } + + /** + * Script constructor, specifying the script content. By default, the script + * type is DML. 
+ * + * @param scriptString + * the script content as a string + */ + public Script(String scriptString) { + this.scriptString = scriptString; + this.scriptType = ScriptType.DML; + } + + /** + * Script constructor, specifying the script content and the type of script + * (DML or PYDML). + * + * @param scriptString + * the script content as a string + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + */ + public Script(String scriptString, ScriptType scriptType) { + this.scriptString = scriptString; + this.scriptType = scriptType; + } + + /** + * Obtain the script type. + * + * @return {@code ScriptType.DML} or {@code ScriptType.PYDML} + */ + public ScriptType getScriptType() { + return scriptType; + } + + /** + * Set the type of script (DML or PYDML). + * + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + */ + public void setScriptType(ScriptType scriptType) { + this.scriptType = scriptType; + } + + /** + * Obtain the script string. + * + * @return the script string + */ + public String getScriptString() { + return scriptString; + } + + /** + * Set the script string. + * + * @param scriptString + * the script string + * @return {@code this} Script object to allow chaining of methods + */ + public Script setScriptString(String scriptString) { + this.scriptString = scriptString; + return this; + } + + /** + * Obtain the input variable names as an unmodifiable set of strings. + * + * @return the input variable names + */ + public Set<String> getInputVariables() { + return Collections.unmodifiableSet(inputVariables); + } + + /** + * Obtain the output variable names as an unmodifiable set of strings. + * + * @return the output variable names + */ + public Set<String> getOutputVariables() { + return Collections.unmodifiableSet(outputVariables); + } + + /** + * Obtain the symbol table, which is essentially a + * {@code HashMap<String, Data>} representing variables and their values. 
+ * + * @return the symbol table + */ + public LocalVariableMap getSymbolTable() { + return symbolTable; + } + + /** + * Obtain an unmodifiable map of all inputs (parameters ($) and variables). + * + * @return all inputs to the script + */ + public Map<String, Object> getInputs() { + return Collections.unmodifiableMap(inputs); + } + + /** + * Obtain an unmodifiable map of input matrix metadata. + * + * @return input matrix metadata + */ + public Map<String, MatrixMetadata> getInputMatrixMetadata() { + return Collections.unmodifiableMap(inputMatrixMetadata); + } + + /** + * Pass a map of inputs to the script. + * + * @param inputs + * map of inputs (parameters ($) and variables). + * @return {@code this} Script object to allow chaining of methods + */ + public Script in(Map<String, Object> inputs) { + for (Entry<String, Object> input : inputs.entrySet()) { + in(input.getKey(), input.getValue()); + } + + return this; + } + + /** + * Pass a Scala Map of inputs to the script. + * + * @param inputs + * Scala Map of inputs (parameters ($) and variables). + * @return {@code this} Script object to allow chaining of methods + */ + public Script in(scala.collection.Map<String, Object> inputs) { + Map<String, Object> javaMap = JavaConversions.mapAsJavaMap(inputs); + in(javaMap); + + return this; + } + + /** + * Pass a Scala Seq of inputs to the script. The inputs are either two-value + * or three-value tuples, where the first value is the variable name, the + * second value is the variable value, and the third optional value is the + * metadata. + * + * @param inputs + * Scala Seq of inputs (parameters ($) and variables). 
+ * @return {@code this} Script object to allow chaining of methods + */ + public Script in(scala.collection.Seq<Object> inputs) { + List<Object> list = JavaConversions.asJavaList(inputs); + for (Object obj : list) { + if (obj instanceof Tuple3) { + @SuppressWarnings("unchecked") + Tuple3<String, Object, MatrixMetadata> t3 = (Tuple3<String, Object, MatrixMetadata>) obj; + in(t3._1(), t3._2(), t3._3()); + } else if (obj instanceof Tuple2) { + @SuppressWarnings("unchecked") + Tuple2<String, Object> t2 = (Tuple2<String, Object>) obj; + in(t2._1(), t2._2()); + } else { + throw new MLContextException("Only Tuples of 2 or 3 values are permitted"); + } + } + return this; + } + + /** + * Obtain an unmodifiable map of all input parameters ($). + * + * @return input parameters ($) + */ + public Map<String, Object> getInputParameters() { + return inputParameters; + } + + /** + * Register an input (parameter ($) or variable). + * + * @param name + * name of the input + * @param value + * value of the input + * @return {@code this} Script object to allow chaining of methods + */ + public Script in(String name, Object value) { + return in(name, value, null); + } + + /** + * Register an input (parameter ($) or variable) with optional matrix + * metadata. 
+ * + * @param name + * name of the input + * @param value + * value of the input + * @param matrixMetadata + * optional matrix metadata + * @return {@code this} Script object to allow chaining of methods + */ + public Script in(String name, Object value, MatrixMetadata matrixMetadata) { + MLContextUtil.checkInputValueType(name, value); + if (inputs == null) { + inputs = new LinkedHashMap<String, Object>(); + } + inputs.put(name, value); + + if (name.startsWith("$")) { + MLContextUtil.checkInputParameterType(name, value); + if (inputParameters == null) { + inputParameters = new LinkedHashMap<String, Object>(); + } + inputParameters.put(name, value); + } else { + Data data = MLContextUtil.convertInputType(name, value, matrixMetadata); + if (data != null) { + symbolTable.put(name, data); + inputVariables.add(name); + if (data instanceof MatrixObject) { + if (matrixMetadata != null) { + inputMatrixMetadata.put(name, matrixMetadata); + } + } + } + } + return this; + } + + /** + * Register an output variable. + * + * @param outputName + * name of the output variable + * @return {@code this} Script object to allow chaining of methods + */ + public Script out(String outputName) { + outputVariables.add(outputName); + return this; + } + + /** + * Register output variables. + * + * @param outputNames + * names of the output variables + * @return {@code this} Script object to allow chaining of methods + */ + public Script out(String... outputNames) { + outputVariables.addAll(Arrays.asList(outputNames)); + return this; + } + + /** + * Clear the inputs, outputs, and symbol table. + */ + public void clearIOS() { + clearInputs(); + clearOutputs(); + clearSymbolTable(); + } + + /** + * Clear the inputs and outputs, but not the symbol table. + */ + public void clearIO() { + clearInputs(); + clearOutputs(); + } + + /** + * Clear the script string, inputs, outputs, and symbol table. + */ + public void clearAll() { + scriptString = null; + clearIOS(); + } + + /** + * Clear the inputs. 
+     */
+    public void clearInputs() {
+        // all four input-related collections are emptied together so they stay in sync
+        inputs.clear();
+        inputParameters.clear();
+        inputVariables.clear();
+        inputMatrixMetadata.clear();
+    }
+
+    /**
+     * Clear the outputs, i.e. the registered output variable names.
+     */
+    public void clearOutputs() {
+        outputVariables.clear();
+    }
+
+    /**
+     * Clear the symbol table by removing all of its entries.
+     */
+    public void clearSymbolTable() {
+        symbolTable.removeAll();
+    }
+
+    /**
+     * Obtain the results of the script execution. Returns the same value as
+     * {@link #getResults()}.
+     *
+     * @return the results of the script execution, as last set through
+     *         {@link #setResults(MLResults)}
+     */
+    public MLResults results() {
+        return results;
+    }
+
+    /**
+     * Obtain the results of the script execution.
+     *
+     * @return the results of the script execution, as last set through
+     *         {@link #setResults(MLResults)}
+     */
+    public MLResults getResults() {
+        return results;
+    }
+
+    /**
+     * Set the results of the script execution.
+     *
+     * @param results
+     *            the results of the script execution.
+     */
+    public void setResults(MLResults results) {
+        this.results = results;
+    }
+
+    /**
+     * Obtain the script executor used by this Script.
+     *
+     * @return the ScriptExecutor used by this Script, as last set through
+     *         {@link #setScriptExecutor(ScriptExecutor)}
+     */
+    public ScriptExecutor getScriptExecutor() {
+        return scriptExecutor;
+    }
+
+    /**
+     * Set the ScriptExecutor used by this Script.
+     *
+     * @param scriptExecutor
+     *            the script executor
+     */
+    public void setScriptExecutor(ScriptExecutor scriptExecutor) {
+        this.scriptExecutor = scriptExecutor;
+    }
+
+    /**
+     * Is the script type DML? Delegates to the script type.
+     *
+     * @return {@code true} if the script type is DML, {@code false} otherwise
+     */
+    public boolean isDML() {
+        return scriptType.isDML();
+    }
+
+    /**
+     * Is the script type PYDML? Delegates to the script type.
+     *
+     * @return {@code true} if the script type is PYDML, {@code false} otherwise
+     */
+    public boolean isPYDML() {
+        return scriptType.isPYDML();
+    }
+
+    /**
+     * Generate the script execution string, which adds read/load/write/save
+     * statements to the beginning and end of the script to execute.
+ * + * @return the script execution string + */ + public String getScriptExecutionString() { + StringBuilder sb = new StringBuilder(); + + Set<String> ins = getInputVariables(); + for (String in : ins) { + Object inValue = getInputs().get(in); + sb.append(in); + if (isDML()) { + if (inValue instanceof String) { + String quotedString = MLContextUtil.quotedString((String) inValue); + sb.append(" = " + quotedString + ";\n"); + } else if (MLContextUtil.isBasicType(inValue)) { + sb.append(" = read('', data_type='scalar');\n"); + } else { + sb.append(" = read('');\n"); + } + } else if (isPYDML()) { + if (inValue instanceof String) { + String quotedString = MLContextUtil.quotedString((String) inValue); + sb.append(" = " + quotedString + "\n"); + } else if (MLContextUtil.isBasicType(inValue)) { + sb.append(" = load('', data_type='scalar')\n"); + } else { + sb.append(" = load('')\n"); + } + } + + } + + sb.append(getScriptString()); + if (!getScriptString().endsWith("\n")) { + sb.append("\n"); + } + + Set<String> outs = getOutputVariables(); + for (String out : outs) { + if (isDML()) { + sb.append("write("); + sb.append(out); + sb.append(", '');\n"); + } else if (isPYDML()) { + sb.append("save("); + sb.append(out); + sb.append(", '')\n"); + } + } + + return sb.toString(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + sb.append(MLContextUtil.displayInputs("Inputs", inputs)); + sb.append("\n"); + sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); + return sb.toString(); + } + + /** + * Display information about the script as a String. This consists of the + * script type, inputs, outputs, input parameters, input variables, output + * variables, the symbol table, the script string, and the script execution + * string. 
+ * + * @return information about this script as a String + */ + public String info() { + StringBuilder sb = new StringBuilder(); + + sb.append("Script Type: "); + sb.append(scriptType); + sb.append("\n\n"); + sb.append(MLContextUtil.displayInputs("Inputs", inputs)); + sb.append("\n"); + sb.append(MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable)); + sb.append("\n"); + sb.append(MLContextUtil.displayMap("Input Parameters", inputParameters)); + sb.append("\n"); + sb.append(MLContextUtil.displaySet("Input Variables", inputVariables)); + sb.append("\n"); + sb.append(MLContextUtil.displaySet("Output Variables", outputVariables)); + sb.append("\n"); + sb.append(MLContextUtil.displaySymbolTable("Symbol Table", symbolTable)); + sb.append("\nScript String:\n"); + sb.append(scriptString); + sb.append("\nScript Execution String:\n"); + sb.append(getScriptExecutionString()); + sb.append("\n"); + + return sb.toString(); + } + + /** + * Display the script inputs. + * + * @return the script inputs + */ + public String displayInputs() { + return MLContextUtil.displayInputs("Inputs", inputs); + } + + /** + * Display the script outputs. + * + * @return the script outputs as a String + */ + public String displayOutputs() { + return MLContextUtil.displayOutputs("Outputs", outputVariables, symbolTable); + } + + /** + * Display the script input parameters. + * + * @return the script input parameters as a String + */ + public String displayInputParameters() { + return MLContextUtil.displayMap("Input Parameters", inputParameters); + } + + /** + * Display the script input variables. + * + * @return the script input variables as a String + */ + public String displayInputVariables() { + return MLContextUtil.displaySet("Input Variables", inputVariables); + } + + /** + * Display the script output variables. 
+     *
+     * @return the script output variables as a String
+     */
+    public String displayOutputVariables() {
+        return MLContextUtil.displaySet("Output Variables", outputVariables);
+    }
+
+    /**
+     * Display the script symbol table under the heading "Symbol Table".
+     *
+     * @return the script symbol table as a String
+     */
+    public String displaySymbolTable() {
+        return MLContextUtil.displaySymbolTable("Symbol Table", symbolTable);
+    }
+
+    /**
+     * Obtain the script name.
+     *
+     * @return the script name, or {@code null} if no name has been set
+     */
+    public String getName() {
+        return name;
+    }
+
+    /**
+     * Set the script name.
+     *
+     * @param name
+     *            the script name
+     * @return {@code this} Script object to allow chaining of methods
+     */
+    public Script setName(String name) {
+        this.name = name;
+        return this;
+    }
+
+}
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java new file mode 100644 index 0000000..4702af2 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptExecutor.java @@ -0,0 +1,624 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.api.mlcontext; + +import java.io.IOException; +import java.util.Map; +import java.util.Set; + +import org.apache.commons.lang3.StringUtils; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.jmlc.JMLCUtils; +import org.apache.sysml.api.monitoring.SparkMonitoringUtil; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.conf.DMLConfig; +import org.apache.sysml.hops.HopsException; +import org.apache.sysml.hops.OptimizerUtils; +import org.apache.sysml.hops.OptimizerUtils.OptimizationLevel; +import org.apache.sysml.hops.globalopt.GlobalOptimizerWrapper; +import org.apache.sysml.hops.rewrite.ProgramRewriter; +import org.apache.sysml.hops.rewrite.RewriteRemovePersistentReadWrite; +import org.apache.sysml.lops.LopsException; +import org.apache.sysml.parser.AParserWrapper; +import org.apache.sysml.parser.DMLProgram; +import org.apache.sysml.parser.DMLTranslator; +import org.apache.sysml.parser.LanguageException; +import org.apache.sysml.parser.ParseException; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.Program; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysml.utils.Explain; +import org.apache.sysml.utils.Explain.ExplainCounts; +import org.apache.sysml.utils.Statistics; + +/** + * ScriptExecutor executes a DML or PYDML Script object using SystemML. This is + * accomplished by calling the {@link #execute} method. 
+ * <p>
+ * Script execution via the MLContext API typically consists of the following
+ * steps:
+ * </p>
+ * <ol>
+ * <li>Language Steps
+ * <ol>
+ * <li>Parse script into program</li>
+ * <li>Live variable analysis</li>
+ * <li>Validate program</li>
+ * </ol>
+ * </li>
+ * <li>HOP (High-Level Operator) Steps
+ * <ol>
+ * <li>Construct HOP DAGs</li>
+ * <li>Static rewrites</li>
+ * <li>Intra-/Inter-procedural analysis</li>
+ * <li>Dynamic rewrites</li>
+ * <li>Compute memory estimates</li>
+ * <li>Rewrite persistent reads and writes (MLContext-specific)</li>
+ * </ol>
+ * </li>
+ * <li>LOP (Low-Level Operator) Steps
+ * <ol>
+ * <li>Construct LOP DAGs</li>
+ * <li>Generate runtime program</li>
+ * <li>Execute runtime program</li>
+ * <li>Dynamic recompilation</li>
+ * </ol>
+ * </li>
+ * </ol>
+ * <p>
+ * Modifications to these steps can be accomplished by subclassing
+ * ScriptExecutor. For example, the following code will turn off the global data
+ * flow optimization check by subclassing ScriptExecutor and overriding the
+ * globalDataFlowOptimization method.
+ * </p>
+ *
+ * <code>ScriptExecutor scriptExecutor = new ScriptExecutor() {
+ * <br>  // turn off global data flow optimization check
+ * <br>  @Override
+ * <br>  protected void globalDataFlowOptimization() {
+ * <br>    return;
+ * <br>  }
+ * <br>};
+ * <br>ml.execute(script, scriptExecutor);</code>
+ * <p>
+ * For more information, please see the {@link #execute} method.
+ * </p>
+ */
+public class ScriptExecutor {
+
+    /** Configuration properties; passed to runtime program generation and to Hadoop execution initialization. */
+    protected DMLConfig config;
+    /** Optional Spark monitor; may be {@code null}, in which case the monitor hooks do nothing. */
+    protected SparkMonitoringUtil sparkMonitoringUtil;
+    /** The parsed program; assigned by {@link #parseScript()}. */
+    protected DMLProgram dmlProgram;
+    /** Translator for {@link #dmlProgram}; created in {@link #liveVariableAnalysis()}. */
+    protected DMLTranslator dmlTranslator;
+    /** The executable program; assigned by {@link #generateRuntimeProgram()}. */
+    protected Program runtimeProgram;
+    /** Context the runtime program runs in; created in {@link #createAndInitializeExecutionContext()}. */
+    protected ExecutionContext executionContext;
+    /** The script being executed; assigned at the start of {@link #execute}. */
+    protected Script script;
+    /** When {@code true}, {@link #showExplanation()} prints a description of the program. */
+    protected boolean explain = false;
+    /** When {@code true}, {@link #execute} prints the statistics after execution. */
+    protected boolean statistics = false;
+
+    /**
+     * ScriptExecutor constructor.
+ */ + public ScriptExecutor() { + config = ConfigurationManager.getDMLConfig(); + } + + /** + * ScriptExecutor constructor, where the configuration properties are passed + * in. + * + * @param config + * the configuration properties to use by the ScriptExecutor + */ + public ScriptExecutor(DMLConfig config) { + this.config = config; + ConfigurationManager.setGlobalConfig(config); + } + + /** + * ScriptExecutor constructor, where a SparkMonitoringUtil object is passed + * in. + * + * @param sparkMonitoringUtil + * SparkMonitoringUtil object to monitor Spark + */ + public ScriptExecutor(SparkMonitoringUtil sparkMonitoringUtil) { + this(); + this.sparkMonitoringUtil = sparkMonitoringUtil; + } + + /** + * ScriptExecutor constructor, where the configuration properties and a + * SparkMonitoringUtil object are passed in. + * + * @param config + * the configuration properties to use by the ScriptExecutor + * @param sparkMonitoringUtil + * SparkMonitoringUtil object to monitor Spark + */ + public ScriptExecutor(DMLConfig config, SparkMonitoringUtil sparkMonitoringUtil) { + this.config = config; + this.sparkMonitoringUtil = sparkMonitoringUtil; + } + + /** + * Construct DAGs of high-level operators (HOPs) for each block of + * statements. + */ + protected void constructHops() { + try { + dmlTranslator.constructHops(dmlProgram); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e); + } catch (ParseException e) { + throw new MLContextException("Exception occurred while constructing HOPS (high-level operators)", e); + } + } + + /** + * Apply static rewrites, perform intra-/inter-procedural analysis to + * propagate size information into functions, apply dynamic rewrites, and + * compute memory estimates for all HOPs. 
+ */ + protected void rewriteHops() { + try { + dmlTranslator.rewriteHopsDAG(dmlProgram); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); + } catch (HopsException e) { + throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); + } catch (ParseException e) { + throw new MLContextException("Exception occurred while rewriting HOPS (high-level operators)", e); + } + } + + /** + * Output a description of the program to standard output. + */ + protected void showExplanation() { + if (explain) { + try { + System.out.println(Explain.explain(dmlProgram)); + } catch (HopsException e) { + throw new MLContextException("Exception occurred while explaining dml program", e); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred while explaining dml program", e); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while explaining dml program", e); + } + } + } + + /** + * Construct DAGs of low-level operators (LOPs) based on the DAGs of + * high-level operators (HOPs). + */ + protected void constructLops() { + try { + dmlTranslator.constructLops(dmlProgram); + } catch (ParseException e) { + throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); + } catch (HopsException e) { + throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); + } catch (LopsException e) { + throw new MLContextException("Exception occurred while constructing LOPS (low-level operators)", e); + } + } + + /** + * Create runtime program. For each namespace, translate function statement + * blocks into function program blocks and add these to the runtime program. 
+ * For each top-level block, add the program block to the runtime program. + */ + protected void generateRuntimeProgram() { + try { + runtimeProgram = dmlProgram.getRuntimeProgram(config); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while generating runtime program", e); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred while generating runtime program", e); + } catch (LopsException e) { + throw new MLContextException("Exception occurred while generating runtime program", e); + } catch (IOException e) { + throw new MLContextException("Exception occurred while generating runtime program", e); + } + } + + /** + * Count the number of compiled MR Jobs/Spark Instructions in the runtime + * program and set this value in the statistics. + */ + protected void countCompiledMRJobsAndSparkInstructions() { + ExplainCounts counts = Explain.countDistributedOperations(runtimeProgram); + Statistics.resetNoOfCompiledJobs(counts.numJobs); + } + + /** + * Create an execution context and set its variables to be the symbol table + * of the script. + */ + protected void createAndInitializeExecutionContext() { + executionContext = ExecutionContextFactory.createContext(runtimeProgram); + LocalVariableMap symbolTable = script.getSymbolTable(); + if (symbolTable != null) { + executionContext.setVariables(symbolTable); + } + } + + /** + * Execute a DML or PYDML script. 
This is broken down into the following
+     * primary methods:
+     *
+     * <ol>
+     * <li>{@link #parseScript()}</li>
+     * <li>{@link #liveVariableAnalysis()}</li>
+     * <li>{@link #validateScript()}</li>
+     * <li>{@link #constructHops()}</li>
+     * <li>{@link #rewriteHops()}</li>
+     * <li>{@link #showExplanation()}</li>
+     * <li>{@link #rewritePersistentReadsAndWrites()}</li>
+     * <li>{@link #constructLops()}</li>
+     * <li>{@link #generateRuntimeProgram()}</li>
+     * <li>{@link #globalDataFlowOptimization()}</li>
+     * <li>{@link #countCompiledMRJobsAndSparkInstructions()}</li>
+     * <li>{@link #initializeCachingAndScratchSpace()}</li>
+     * <li>{@link #cleanupRuntimeProgram()}</li>
+     * <li>{@link #createAndInitializeExecutionContext()}</li>
+     * <li>{@link #executeRuntimeProgram()}</li>
+     * <li>{@link #cleanupAfterExecution()}</li>
+     * </ol>
+     *
+     * @param script
+     *            the DML or PYDML script to execute
+     * @return the MLResults object created from the script after execution;
+     *         the same object is also stored on the script
+     */
+    public MLResults execute(Script script) {
+        this.script = script;
+        // fail fast before any compilation work if the script is unusable
+        checkScriptHasTypeAndString();
+        script.setScriptExecutor(this);
+        setScriptStringInSparkMonitor();
+
+        // main steps in script execution
+        parseScript();
+        liveVariableAnalysis();
+        validateScript();
+        constructHops();
+        rewriteHops();
+        showExplanation();
+        rewritePersistentReadsAndWrites();
+        constructLops();
+        generateRuntimeProgram();
+        globalDataFlowOptimization();
+        countCompiledMRJobsAndSparkInstructions();
+        initializeCachingAndScratchSpace();
+        cleanupRuntimeProgram();
+        createAndInitializeExecutionContext();
+        executeRuntimeProgram();
+        setExplainRuntimeProgramInSparkMonitor();
+        cleanupAfterExecution();
+
+        // add symbol table to MLResults
+        MLResults mlResults = new MLResults(script);
+        script.setResults(mlResults);
+
+        // statistics are printed only when enabled on this executor
+        if (statistics) {
+            System.out.println(Statistics.display());
+        }
+
+        return mlResults;
+    }
+
+    /**
+     * Perform any necessary cleanup operations after program execution.
+ */ + protected void cleanupAfterExecution() { + restoreInputsInSymbolTable(); + } + + /** + * Restore the input variables in the symbol table after script execution. + */ + protected void restoreInputsInSymbolTable() { + Map<String, Object> inputs = script.getInputs(); + Map<String, MatrixMetadata> inputMatrixMetadata = script.getInputMatrixMetadata(); + LocalVariableMap symbolTable = script.getSymbolTable(); + Set<String> inputVariables = script.getInputVariables(); + for (String inputVariable : inputVariables) { + if (symbolTable.get(inputVariable) == null) { + // retrieve optional metadata if it exists + MatrixMetadata mm = inputMatrixMetadata.get(inputVariable); + script.in(inputVariable, inputs.get(inputVariable), mm); + } + } + } + + /** + * Remove rmvar instructions so as to maintain registered outputs after the + * program terminates. + */ + protected void cleanupRuntimeProgram() { + JMLCUtils.cleanupRuntimeProgram(runtimeProgram, (script.getOutputVariables() == null) ? new String[0] : script + .getOutputVariables().toArray(new String[0])); + } + + /** + * Execute the runtime program. This involves execution of the program + * blocks that make up the runtime program and may involve dynamic + * recompilation. + */ + protected void executeRuntimeProgram() { + try { + runtimeProgram.execute(executionContext); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred while executing runtime program", e); + } + } + + /** + * Obtain the SparkMonitoringUtil object. + * + * @return the SparkMonitoringUtil object, if available + */ + public SparkMonitoringUtil getSparkMonitoringUtil() { + return sparkMonitoringUtil; + } + + /** + * Check security, create scratch space, cleanup working directories, + * initialize caching, and reset statistics. 
+ */ + protected void initializeCachingAndScratchSpace() { + try { + DMLScript.initHadoopExecution(config); + } catch (ParseException e) { + throw new MLContextException("Exception occurred initializing caching and scratch space", e); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred initializing caching and scratch space", e); + } catch (IOException e) { + throw new MLContextException("Exception occurred initializing caching and scratch space", e); + } + } + + /** + * Optimize the program. + */ + protected void globalDataFlowOptimization() { + if (OptimizerUtils.isOptLevel(OptimizationLevel.O4_GLOBAL_TIME_MEMORY)) { + try { + runtimeProgram = GlobalOptimizerWrapper.optimizeProgram(dmlProgram, runtimeProgram); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred during global data flow optimization", e); + } catch (HopsException e) { + throw new MLContextException("Exception occurred during global data flow optimization", e); + } catch (LopsException e) { + throw new MLContextException("Exception occurred during global data flow optimization", e); + } + } + } + + /** + * Parse the script into an ANTLR parse tree, and convert this parse tree + * into a SystemML program. Parsing includes lexical/syntactic analysis. 
+ */ + protected void parseScript() { + try { + AParserWrapper parser = AParserWrapper.createParser(script.getScriptType().isPYDML()); + Map<String, Object> inputParameters = script.getInputParameters(); + Map<String, String> inputParametersStringMaps = MLContextUtil.convertInputParametersForParser( + inputParameters, script.getScriptType()); + + String scriptExecutionString = script.getScriptExecutionString(); + dmlProgram = parser.parse(null, scriptExecutionString, inputParametersStringMaps); + } catch (ParseException e) { + throw new MLContextException("Exception occurred while parsing script", e); + } + } + + /** + * Replace persistent reads and writes with transient reads and writes in + * the symbol table. + */ + protected void rewritePersistentReadsAndWrites() { + LocalVariableMap symbolTable = script.getSymbolTable(); + if (symbolTable != null) { + String[] inputs = (script.getInputVariables() == null) ? new String[0] : script.getInputVariables() + .toArray(new String[0]); + String[] outputs = (script.getOutputVariables() == null) ? new String[0] : script.getOutputVariables() + .toArray(new String[0]); + RewriteRemovePersistentReadWrite rewrite = new RewriteRemovePersistentReadWrite(inputs, outputs); + ProgramRewriter programRewriter = new ProgramRewriter(rewrite); + try { + programRewriter.rewriteProgramHopDAGs(dmlProgram); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e); + } catch (HopsException e) { + throw new MLContextException("Exception occurred while rewriting persistent reads and writes", e); + } + } + + } + + /** + * Set the SystemML configuration properties. + * + * @param config + * The configuration properties + */ + public void setConfig(DMLConfig config) { + this.config = config; + ConfigurationManager.setGlobalConfig(config); + } + + /** + * Set the explanation of the runtime program in the SparkMonitoringUtil if + * it exists. 
+ */ + protected void setExplainRuntimeProgramInSparkMonitor() { + if (sparkMonitoringUtil != null) { + try { + String explainOutput = Explain.explain(runtimeProgram); + sparkMonitoringUtil.setExplainOutput(explainOutput); + } catch (HopsException e) { + throw new MLContextException("Exception occurred while explaining runtime program", e); + } + } + + } + + /** + * Set the script string in the SparkMonitoringUtil if it exists. Nothing + * happens when no SparkMonitoringUtil has been set. + */ + protected void setScriptStringInSparkMonitor() { + if (sparkMonitoringUtil != null) { + sparkMonitoringUtil.setDMLString(script.getScriptString()); + } + } + + /** + * Set the SparkMonitoringUtil object. + * + * @param sparkMonitoringUtil + * The SparkMonitoringUtil object + */ + public void setSparkMonitoringUtil(SparkMonitoringUtil sparkMonitoringUtil) { + this.sparkMonitoringUtil = sparkMonitoringUtil; + } + + /** + * Liveness analysis is performed on the program, obtaining sets of live-in + * and live-out variables by forward and backward passes over the program. + * A new {@code DMLTranslator} is created for the program as part of this + * step and stored in this executor. + * + * @throws MLContextException + * if the analysis raises a runtime or language exception + */ + protected void liveVariableAnalysis() { + try { + dmlTranslator = new DMLTranslator(dmlProgram); + dmlTranslator.liveVariableAnalysis(dmlProgram); + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception occurred during live variable analysis", e); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred during live variable analysis", e); + } + } + + /** + * Semantically validate the program's expressions, statements, and + * statement blocks in a single recursive pass over the program. Constant + * and size propagation occurs during this step.
+ */ + protected void validateScript() { + try { + dmlTranslator.validateParseTree(dmlProgram); + } catch (LanguageException e) { + throw new MLContextException("Exception occurred while validating script", e); + } catch (ParseException e) { + throw new MLContextException("Exception occurred while validating script", e); + } catch (IOException e) { + throw new MLContextException("Exception occurred while validating script", e); + } + } + + /** + * Check that the Script object has a type (DML or PYDML) and a string + * representing the content of the Script. + * + * @throws MLContextException + * if the Script is null, its type is null, or its script string + * is null or blank + */ + protected void checkScriptHasTypeAndString() { + if (script == null) { + throw new MLContextException("Script is null"); + } else if (script.getScriptType() == null) { + throw new MLContextException("ScriptType (DML or PYDML) needs to be specified"); + } else if (script.getScriptString() == null) { + throw new MLContextException("Script string is null"); + } else if (StringUtils.isBlank(script.getScriptString())) { + throw new MLContextException("Script string is blank"); + } + } + + /** + * Obtain the program + * + * @return the program + */ + public DMLProgram getDmlProgram() { + return dmlProgram; + } + + /** + * Obtain the translator + * + * @return the translator + */ + public DMLTranslator getDmlTranslator() { + return dmlTranslator; + } + + /** + * Obtain the runtime program + * + * @return the runtime program + */ + public Program getRuntimeProgram() { + return runtimeProgram; + } + + /** + * Obtain the execution context + * + * @return the execution context + */ + public ExecutionContext getExecutionContext() { + return executionContext; + } + + /** + * Obtain the Script object associated with this ScriptExecutor + * + * @return the Script object associated with this ScriptExecutor + */ + public Script getScript() { + return script; + } + + /** + * Whether or not an explanation of the DML/PYDML program should be output + * to standard output.
+ * + * @param explain + * {@code true} if explanation should be output, {@code false} + * otherwise + */ + public void setExplain(boolean explain) { + this.explain = explain; + } + + /** + * Set whether or not statistics about the DML/PYDML program should be + * output to standard output. + * + * @param statistics + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public void setStatistics(boolean statistics) { + this.statistics = statistics; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java new file mode 100644 index 0000000..5f0e56b --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptFactory.java @@ -0,0 +1,422 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.sysml.api.mlcontext; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.net.MalformedURLException; +import java.net.URL; + +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.IOUtils; +import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.runtime.util.LocalFileUtils; + +/** + * Factory for creating DML and PYDML Script objects from strings, files, URLs, + * and input streams. + * + */ +public class ScriptFactory { + + /** + * Create a DML Script object based on a string path to a file. A path + * starting with {@code hdfs:} or {@code gpfs:} is read through the Hadoop + * file system; any other path is read from the local file system. + * + * @param scriptFilePath + * path to DML script file (local or HDFS) + * @return DML Script object + */ + public static Script dmlFromFile(String scriptFilePath) { + return scriptFromFile(scriptFilePath, ScriptType.DML); + } + + /** + * Create a DML Script object based on an input stream. + * + * @param inputStream + * input stream to DML + * @return DML Script object + */ + public static Script dmlFromInputStream(InputStream inputStream) { + return scriptFromInputStream(inputStream, ScriptType.DML); + } + + /** + * Create a DML Script object based on a file in the local file system. To + * create a DML Script object from a local file or HDFS, please use + * {@link #dmlFromFile(String)}. + * + * @param localScriptFile + * the local DML file + * @return DML Script object + */ + public static Script dmlFromLocalFile(File localScriptFile) { + return scriptFromLocalFile(localScriptFile, ScriptType.DML); + } + + /** + * Create a DML Script object based on a string. + * + * @param scriptString + * string of DML + * @return DML Script object + */ + public static Script dmlFromString(String scriptString) { + return scriptFromString(scriptString, ScriptType.DML); + } + + /** + * Create a DML Script object based on a URL path.
+ * + * @param scriptUrlPath + * URL path to DML script + * @return DML Script object + */ + public static Script dmlFromUrl(String scriptUrlPath) { + return scriptFromUrl(scriptUrlPath, ScriptType.DML); + } + + /** + * Create a DML Script object based on a URL. Only {@code http} and + * {@code https} URLs are supported. + * + * @param scriptUrl + * URL to DML script + * @return DML Script object + */ + public static Script dmlFromUrl(URL scriptUrl) { + return scriptFromUrl(scriptUrl, ScriptType.DML); + } + + /** + * Create a PYDML Script object based on a string path to a file. A path + * starting with {@code hdfs:} or {@code gpfs:} is read through the Hadoop + * file system; any other path is read from the local file system. + * + * @param scriptFilePath + * path to PYDML script file (local or HDFS) + * @return PYDML Script object + */ + public static Script pydmlFromFile(String scriptFilePath) { + return scriptFromFile(scriptFilePath, ScriptType.PYDML); + } + + /** + * Create a PYDML Script object based on an input stream. + * + * @param inputStream + * input stream to PYDML + * @return PYDML Script object + */ + public static Script pydmlFromInputStream(InputStream inputStream) { + return scriptFromInputStream(inputStream, ScriptType.PYDML); + } + + /** + * Create a PYDML Script object based on a file in the local file system. + * To create a PYDML Script object from a local file or HDFS, please use + * {@link #pydmlFromFile(String)}. + * + * @param localScriptFile + * the local PYDML file + * @return PYDML Script object + */ + public static Script pydmlFromLocalFile(File localScriptFile) { + return scriptFromLocalFile(localScriptFile, ScriptType.PYDML); + } + + /** + * Create a PYDML Script object based on a string. + * + * @param scriptString + * string of PYDML + * @return PYDML Script object + */ + public static Script pydmlFromString(String scriptString) { + return scriptFromString(scriptString, ScriptType.PYDML); + } + + /** + * Create a PYDML Script object based on a URL path.
+ * + * @param scriptUrlPath + * URL path to PYDML script + * @return PYDML Script object + */ + public static Script pydmlFromUrl(String scriptUrlPath) { + return scriptFromUrl(scriptUrlPath, ScriptType.PYDML); + } + + /** + * Create a PYDML Script object based on a URL. Only {@code http} and + * {@code https} URLs are supported. + * + * @param scriptUrl + * URL to PYDML script + * @return PYDML Script object + */ + public static Script pydmlFromUrl(URL scriptUrl) { + return scriptFromUrl(scriptUrl, ScriptType.PYDML); + } + + /** + * Create a DML or PYDML Script object based on a string path to a file. + * The name of the resulting Script is set to the file path. + * + * @param scriptFilePath + * path to DML or PYDML script file (local or HDFS) + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromFile(String scriptFilePath, ScriptType scriptType) { + String scriptString = getScriptStringFromFile(scriptFilePath); + return scriptFromString(scriptString, scriptType).setName(scriptFilePath); + } + + /** + * Create a DML or PYDML Script object based on an input stream. No name is + * set on the resulting Script. + * + * @param inputStream + * input stream to DML or PYDML + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromInputStream(InputStream inputStream, ScriptType scriptType) { + String scriptString = getScriptStringFromInputStream(inputStream); + return scriptFromString(scriptString, scriptType); + } + + /** + * Create a DML or PYDML Script object based on a file in the local file + * system. To create a Script object from a local file or HDFS, please use + * {@link #scriptFromFile(String, ScriptType)}.
+ * The name of the resulting Script is set to the file's name (without its + * parent path). + * + * @param localScriptFile + * The local DML or PYDML file + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromLocalFile(File localScriptFile, ScriptType scriptType) { + String scriptString = getScriptStringFromFile(localScriptFile); + return scriptFromString(scriptString, scriptType).setName(localScriptFile.getName()); + } + + /** + * Create a DML or PYDML Script object based on a string. + * + * @param scriptString + * string of DML or PYDML + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromString(String scriptString, ScriptType scriptType) { + Script script = new Script(scriptString, scriptType); + return script; + } + + /** + * Create a DML or PYDML Script object based on a URL path. The name of the + * resulting Script is set to the URL path. + * + * @param scriptUrlPath + * URL path to DML or PYDML script + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromUrl(String scriptUrlPath, ScriptType scriptType) { + String scriptString = getScriptStringFromUrl(scriptUrlPath); + return scriptFromString(scriptString, scriptType).setName(scriptUrlPath); + } + + /** + * Create a DML or PYDML Script object based on a URL. The name of the + * resulting Script is set to the string form of the URL. + * + * @param scriptUrl + * URL to DML or PYDML script + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return DML or PYDML Script object + */ + private static Script scriptFromUrl(URL scriptUrl, ScriptType scriptType) { + String scriptString = getScriptStringFromUrl(scriptUrl); + return scriptFromString(scriptString, scriptType).setName(scriptUrl.toString()); + } + + /** + * Create a DML Script object based on a string.
+ * + * @param scriptString + * string of DML + * @return DML Script object + */ + public static Script dml(String scriptString) { + return dmlFromString(scriptString); + } + + /** + * Obtain a script string from a file in the local file system. To obtain a + * script string from a file in HDFS, please use + * {@link #getScriptStringFromFile(String)}. + * + * @param file + * The script file. + * @return The script string. + * @throws MLContextException + * If the file is null, its path fails filename validation, or a + * problem occurs reading the script string from the file. + */ + private static String getScriptStringFromFile(File file) { + if (file == null) { + throw new MLContextException("Script file is null"); + } + String filePath = file.getPath(); + try { + if (!LocalFileUtils.validateExternalFilename(filePath, false)) { + throw new MLContextException("Invalid (non-trustworthy) local filename: " + filePath); + } + String scriptString = FileUtils.readFileToString(file); + return scriptString; + } catch (IllegalArgumentException e) { + throw new MLContextException("Error trying to read script string from file: " + filePath, e); + } catch (IOException e) { + throw new MLContextException("Error trying to read script string from file: " + filePath, e); + } + } + + /** + * Obtain a script string from a file.
+ * + * @param scriptFilePath + * The file path to the script file (either local file system or + * HDFS) + * @return The script string + * @throws MLContextException + * If a problem occurs reading the script string from the file + */ + private static String getScriptStringFromFile(String scriptFilePath) { + if (scriptFilePath == null) { + throw new MLContextException("Script file path is null"); + } + try { + if (scriptFilePath.startsWith("hdfs:") || scriptFilePath.startsWith("gpfs:")) { + if (!LocalFileUtils.validateExternalFilename(scriptFilePath, true)) { + throw new MLContextException("Invalid (non-trustworthy) hdfs/gpfs filename: " + scriptFilePath); + } + FileSystem fs = FileSystem.get(ConfigurationManager.getCachedJobConf()); + Path path = new Path(scriptFilePath); + FSDataInputStream fsdis = fs.open(path); + String scriptString = IOUtils.toString(fsdis); + return scriptString; + } else {// from local file system + if (!LocalFileUtils.validateExternalFilename(scriptFilePath, false)) { + throw new MLContextException("Invalid (non-trustworthy) local filename: " + scriptFilePath); + } + File scriptFile = new File(scriptFilePath); + String scriptString = FileUtils.readFileToString(scriptFile); + return scriptString; + } + } catch (IllegalArgumentException e) { + throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e); + } catch (IOException e) { + throw new MLContextException("Error trying to read script string from file: " + scriptFilePath, e); + } + } + + /** + * Obtain a script string from an InputStream. 
+ * + * @param inputStream + * The InputStream from which to read the script string + * @return The script string + * @throws MLContextException + * If the InputStream is null or a problem occurs reading the + * script string from the InputStream + */ + private static String getScriptStringFromInputStream(InputStream inputStream) { + if (inputStream == null) { + throw new MLContextException("InputStream is null"); + } + try { + String scriptString = IOUtils.toString(inputStream); + return scriptString; + } catch (IOException e) { + throw new MLContextException("Error trying to read script string from InputStream", e); + } + } + + /** + * Obtain a script string from a URL. + * + * @param scriptUrlPath + * The URL path to the script file + * @return The script string + * @throws MLContextException + * If the URL path is null or malformed, or a problem occurs + * reading the script string from the URL + */ + private static String getScriptStringFromUrl(String scriptUrlPath) { + if (scriptUrlPath == null) { + throw new MLContextException("Script URL path is null"); + } + try { + URL url = new URL(scriptUrlPath); + return getScriptStringFromUrl(url); + } catch (MalformedURLException e) { + throw new MLContextException("Error trying to read script string from URL path: " + scriptUrlPath, e); + } + } + + /** + * Obtain a script string from a URL.
+ * + * @param url + * The script URL + * @return The script string + * @throws MLContextException + * If a problem occurs reading the script string from the URL + */ + private static String getScriptStringFromUrl(URL url) { + if (url == null) { + throw new MLContextException("URL is null"); + } + String urlString = url.toString(); + if ((!urlString.toLowerCase().startsWith("http:")) && (!urlString.toLowerCase().startsWith("https:"))) { + throw new MLContextException("Currently only reading from http and https URLs is supported"); + } + try { + InputStream is = url.openStream(); + String scriptString = IOUtils.toString(is); + return scriptString; + } catch (IOException e) { + throw new MLContextException("Error trying to read script string from URL: " + url, e); + } + } + + /** + * Create a PYDML script object based on a string. + * + * @param scriptString + * string of PYDML + * @return PYDML Script object + */ + public static Script pydml(String scriptString) { + return pydmlFromString(scriptString); + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java new file mode 100644 index 0000000..94c9057 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/ScriptType.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +/** + * ScriptType represents the type of script, DML (R-like syntax) or PYDML + * (Python-like syntax). + * + */ +public enum ScriptType { + /** + * R-like syntax. + */ + DML, + + /** + * Python-like syntax. + */ + PYDML; + + /** + * Obtain script type as a lowercase string ("dml" or "pydml"). The string + * is the enum's {@code toString()} value converted to lower case. + * + * @return lowercase string representing the script type + */ + public String lowerCase() { + return super.toString().toLowerCase(); + } + + /** + * Is the script type DML? + * + * @return {@code true} if the script type is DML, {@code false} otherwise + */ + public boolean isDML() { + return (this == ScriptType.DML); + } + + /** + * Is the script type PYDML?
+ * + * @return {@code true} if the script type is PYDML, {@code false} otherwise + */ + public boolean isPYDML() { + return (this == ScriptType.PYDML); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 0eea221..c715331 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -36,11 +36,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.storage.RDDInfo; import org.apache.spark.storage.StorageLevel; - -import scala.Tuple2; - import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.MLContext; import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.OptimizerUtils; @@ -82,6 +78,8 @@ import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.runtime.util.UtilFunctions; import org.apache.sysml.utils.Statistics; +import scala.Tuple2; + public class SparkExecutionContext extends ExecutionContext { @@ -178,22 +176,28 @@ public class SparkExecutionContext extends ExecutionContext * */ private synchronized static void initSparkContext() - { + { //check for redundant spark context init if( _spctx != null ) return; - + long t0 = DMLScript.STATISTICS ? 
System.nanoTime() : 0; //create a default spark context (master, appname, etc refer to system properties //as given in the spark configuration or during spark-submit) - MLContext mlCtx = MLContextProxy.getActiveMLContext(); - if(mlCtx != null) + Object mlCtxObj = MLContextProxy.getActiveMLContext(); + if(mlCtxObj != null) { // This is when DML is called through spark shell // Will clean the passing of static variables later as this involves minimal change to DMLScript - _spctx = new JavaSparkContext(mlCtx.getSparkContext()); + if (mlCtxObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj; + _spctx = new JavaSparkContext(mlCtx.getSparkContext()); + } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj; + _spctx = new JavaSparkContext(mlCtx.getSparkContext()); + } } else { @@ -1424,11 +1428,26 @@ public class SparkExecutionContext extends ExecutionContext } } - MLContext mlContext = MLContextProxy.getActiveMLContext(); - if(mlContext != null && mlContext.getMonitoringUtil() != null) { - mlContext.getMonitoringUtil().setLineageInfo(inst, outDebugString); - } - else { + + Object mlContextObj = MLContextProxy.getActiveMLContext(); + if (mlContextObj != null) { + if (mlContextObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlContextObj; + if (mlCtx.getMonitoringUtil() != null) { + mlCtx.getMonitoringUtil().setLineageInfo(inst, outDebugString); + } else { + throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext"); + } + } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj; + if (mlCtx.getSparkMonitoringUtil() != null) 
{ + mlCtx.getSparkMonitoringUtil().setLineageInfo(inst, outDebugString); + } else { + throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext"); + } + } + + } else { throw new DMLRuntimeException("The method setLineageInfoForExplain should be called only through MLContext"); } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java index 0c0d3f0..d5301e7 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/SPInstruction.java @@ -19,7 +19,6 @@ package org.apache.sysml.runtime.instructions.spark; -import org.apache.sysml.api.MLContext; import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.lops.runtime.RunMRJobs; import org.apache.sysml.runtime.DMLRuntimeException; @@ -99,13 +98,23 @@ public abstract class SPInstruction extends Instruction //spark-explain-specific handling of current instructions //This only relevant for ComputationSPInstruction as in postprocess we call setDebugString which is valid only for ComputationSPInstruction - MLContext mlCtx = MLContextProxy.getActiveMLContext(); - if( tmp instanceof ComputationSPInstruction - && mlCtx != null && mlCtx.getMonitoringUtil() != null - && ec instanceof SparkExecutionContext ) - { - mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp); - MLContextProxy.setInstructionForMonitoring(tmp); + Object mlCtxObj = MLContextProxy.getActiveMLContext(); + if (mlCtxObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj; + if (tmp instanceof 
ComputationSPInstruction + && mlCtx != null && mlCtx.getMonitoringUtil() != null + && ec instanceof SparkExecutionContext ) { + mlCtx.getMonitoringUtil().addCurrentInstruction((SPInstruction)tmp); + MLContextProxy.setInstructionForMonitoring(tmp); + } + } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj; + if (tmp instanceof ComputationSPInstruction + && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null + && ec instanceof SparkExecutionContext ) { + mlCtx.getSparkMonitoringUtil().addCurrentInstruction((SPInstruction)tmp); + MLContextProxy.setInstructionForMonitoring(tmp); + } } return tmp; @@ -120,14 +129,25 @@ public abstract class SPInstruction extends Instruction throws DMLRuntimeException { //spark-explain-specific handling of current instructions - MLContext mlCtx = MLContextProxy.getActiveMLContext(); - if( this instanceof ComputationSPInstruction - && mlCtx != null && mlCtx.getMonitoringUtil() != null - && ec instanceof SparkExecutionContext ) - { - SparkExecutionContext sec = (SparkExecutionContext) ec; - sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName()); - mlCtx.getMonitoringUtil().removeCurrentInstruction(this); + Object mlCtxObj = MLContextProxy.getActiveMLContext(); + if (mlCtxObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlCtx = (org.apache.sysml.api.MLContext) mlCtxObj; + if (this instanceof ComputationSPInstruction + && mlCtx != null && mlCtx.getMonitoringUtil() != null + && ec instanceof SparkExecutionContext ) { + SparkExecutionContext sec = (SparkExecutionContext) ec; + sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName()); + mlCtx.getMonitoringUtil().removeCurrentInstruction(this); + } + } else if (mlCtxObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlCtx = 
(org.apache.sysml.api.mlcontext.MLContext) mlCtxObj; + if (this instanceof ComputationSPInstruction + && mlCtx != null && mlCtx.getSparkMonitoringUtil() != null + && ec instanceof SparkExecutionContext ) { + SparkExecutionContext sec = (SparkExecutionContext) ec; + sec.setDebugString(this, ((ComputationSPInstruction) this).getOutputVariableName()); + mlCtx.getSparkMonitoringUtil().removeCurrentInstruction(this); + } } //maintain statistics http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java index 956b841..3bf2f67 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/functions/SparkListener.java @@ -33,16 +33,14 @@ import org.apache.spark.storage.RDDInfo; import org.apache.spark.ui.jobs.StagesTab; import org.apache.spark.ui.jobs.UIData.TaskUIData; import org.apache.spark.ui.scope.RDDOperationGraphListener; +import org.apache.sysml.api.MLContextProxy; +import org.apache.sysml.runtime.instructions.spark.SPInstruction; import scala.Option; import scala.collection.Iterator; import scala.collection.Seq; import scala.xml.Node; -import org.apache.sysml.api.MLContext; -import org.apache.sysml.api.MLContextProxy; -import org.apache.sysml.runtime.instructions.spark.SPInstruction; - // Instead of extending org.apache.spark.JavaSparkListener /** * This class is only used by MLContext for now. It is used to provide UI data for Python notebook. 
@@ -94,9 +92,19 @@ public class SparkListener extends RDDOperationGraphListener { jobDAGs.put(jobID, jobNodes); synchronized(currentInstructions) { for(SPInstruction inst : currentInstructions) { - MLContext mlContext = MLContextProxy.getActiveMLContext(); - if(mlContext != null && mlContext.getMonitoringUtil() != null) { - mlContext.getMonitoringUtil().setJobId(inst, jobID); + Object mlContextObj = MLContextProxy.getActiveMLContext(); + if (mlContextObj != null) { + if (mlContextObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj; + if (mlContext.getMonitoringUtil() != null) { + mlContext.getMonitoringUtil().setJobId(inst, jobID); + } + } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) mlContextObj; + if (mlContext.getSparkMonitoringUtil() != null) { + mlContext.getSparkMonitoringUtil().setJobId(inst, jobID); + } + } } } } @@ -140,9 +148,19 @@ public class SparkListener extends RDDOperationGraphListener { synchronized(currentInstructions) { for(SPInstruction inst : currentInstructions) { - MLContext mlContext = MLContextProxy.getActiveMLContext(); - if(mlContext != null && mlContext.getMonitoringUtil() != null) { - mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId()); + Object mlContextObj = MLContextProxy.getActiveMLContext(); + if (mlContextObj != null) { + if (mlContextObj instanceof org.apache.sysml.api.MLContext) { + org.apache.sysml.api.MLContext mlContext = (org.apache.sysml.api.MLContext) mlContextObj; + if (mlContext.getMonitoringUtil() != null) { + mlContext.getMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId()); + } + } else if (mlContextObj instanceof org.apache.sysml.api.mlcontext.MLContext) { + org.apache.sysml.api.mlcontext.MLContext mlContext = (org.apache.sysml.api.mlcontext.MLContext) 
mlContextObj; + if (mlContext.getSparkMonitoringUtil() != null) { + mlContext.getSparkMonitoringUtil().setStageId(inst, stageSubmitted.stageInfo().stageId()); + } + } } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java index ccdc927..f022e40 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/spark/utils/RDDConverterUtilsExt.java @@ -410,7 +410,7 @@ public class RDDConverterUtilsExt } - private static class DataFrameAnalysisFunction implements Function<Row,Row> { + public static class DataFrameAnalysisFunction implements Function<Row,Row> { private static final long serialVersionUID = 5705371332119770215L; private RowAnalysisFunctionHelper helper = null; boolean isVectorBasedRDD; @@ -445,7 +445,7 @@ public class RDDConverterUtilsExt } - private static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> { + public static class DataFrameToBinaryBlockFunction implements PairFlatMapFunction<Iterator<Tuple2<Row,Long>>,MatrixIndexes,MatrixBlock> { private static final long serialVersionUID = 653447740362447236L; private RowToBinaryBlockFunctionHelper helper = null; boolean isVectorBasedDF;
