Repository: incubator-systemml Updated Branches: refs/heads/master d2b9e5022 -> f3b1aebc2
[SYSTEMML-1236] Move MLContextProxy out of api package MLContextProxy is used internally by SystemML but not by end users, so move this class to a non-API package. Remove some unnecessary casting since only 1 MLContext class now exists. Closes #522. Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/f3b1aebc Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/f3b1aebc Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/f3b1aebc Branch: refs/heads/master Commit: f3b1aebc2b238f456eb7931d44f4f9b2d472a254 Parents: d2b9e50 Author: Deron Eriksson <de...@us.ibm.com> Authored: Thu Jun 1 11:36:47 2017 -0700 Committer: Deron Eriksson <de...@us.ibm.com> Committed: Thu Jun 1 11:36:47 2017 -0700 ---------------------------------------------------------------------- .../org/apache/sysml/api/MLContextProxy.java | 73 ---- .../apache/sysml/api/mlcontext/MLContext.java | 22 +- .../sysml/api/mlcontext/MLContextUtil.java | 98 ++--- .../org/apache/sysml/parser/StatementBlock.java | 432 +++++++++---------- .../runtime/controlprogram/ProgramBlock.java | 166 +++---- .../context/SparkExecutionContext.java | 8 +- .../org/apache/sysml/utils/MLContextProxy.java | 73 ++++ 7 files changed, 436 insertions(+), 436 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/api/MLContextProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLContextProxy.java b/src/main/java/org/apache/sysml/api/MLContextProxy.java deleted file mode 100644 index 18b2eaa..0000000 --- a/src/main/java/org/apache/sysml/api/MLContextProxy.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysml.api; - -import java.util.ArrayList; - -import org.apache.sysml.api.mlcontext.MLContext; -import org.apache.sysml.api.mlcontext.MLContextException; -import org.apache.sysml.parser.Expression; -import org.apache.sysml.parser.LanguageException; -import org.apache.sysml.runtime.instructions.Instruction; - -/** - * The purpose of this proxy is to shield systemml internals from direct access to MLContext - * which would try to load spark libraries and hence fail if these are not available. This - * indirection is much more efficient than catching NoClassDefFoundErrors for every access - * to MLContext (e.g., on each recompile). - * - */ -public class MLContextProxy -{ - - private static boolean _active = false; - - public static void setActive(boolean flag) { - _active = flag; - } - - public static boolean isActive() { - return _active; - } - - public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp) - { - return MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp); - } - - public static void setAppropriateVarsForRead(Expression source, String targetname) - throws LanguageException - { - MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname); - } - - public static Object getActiveMLContext() { - return MLContext.getActiveMLContext(); - } - - public static Object getActiveMLContextForAPI() { - if (MLContext.getActiveMLContext() != null) { - return MLContext.getActiveMLContext(); - } - throw new MLContextException("No MLContext object is currently active. Have you created one? " - + "Hint: in Scala, 'val ml = new MLContext(sc)'", true); - } - -} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java index 272fa0e..8ea9fcc 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -32,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SparkSession; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; -import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.api.jmlc.JMLCUtils; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.conf.DMLConfig; @@ -50,6 +49,7 @@ import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; import org.apache.sysml.runtime.matrix.data.OutputInfo; import org.apache.sysml.utils.Explain.ExplainType; +import org.apache.sysml.utils.MLContextProxy; /** * The MLContext API offers programmatic access to SystemML on Spark from @@ -98,7 +98,7 @@ public class MLContext { * Whether or not GPU mode should be enabled */ private boolean gpu = false; - + /** * Whether or not GPU mode should be force */ @@ -172,7 +172,7 @@ public class MLContext { /** * Create an MLContext based on a SparkSession for interaction with SystemML * on Spark. - * + * * @param spark SparkSession */ public MLContext(SparkSession spark) { @@ -265,7 +265,7 @@ public class MLContext { throw new MLContextException(e); } } - + /** * Execute a DML or PYDML Script. * @@ -364,7 +364,7 @@ public class MLContext { /** * Obtain whether or not all values should be maintained in the symbol table * after execution. - * + * * @return {@code true} if all values should be maintained in the symbol * table, {@code false} otherwise */ @@ -375,7 +375,7 @@ public class MLContext { /** * Set whether or not all values should be maintained in the symbol table * after execution. - * + * * @param maintainSymbolTable * {@code true} if all values should be maintained in the symbol * table, {@code false} otherwise @@ -425,7 +425,7 @@ public class MLContext { public void setGPU(boolean enable) { this.gpu = enable; } - + /** * Whether or not to explicitly "force" the usage of GPU. * If a GPU is not available, and the GPU mode is set or if available memory on GPU is less, SystemML will crash when the program is run. @@ -657,7 +657,7 @@ public class MLContext { /** * Obtain information about the project such as version and build time from * the manifest in the SystemML jar file. - * + * * @return information about the project */ public ProjectInfo info() { @@ -672,7 +672,7 @@ public class MLContext { /** * Obtain the SystemML version number. - * + * * @return the SystemML version number */ public String version() { @@ -684,7 +684,7 @@ public class MLContext { /** * Obtain the SystemML jar file build time. - * + * * @return the SystemML jar file build time */ public String buildTime() { @@ -697,7 +697,7 @@ public class MLContext { /** * Obtain the maximum number of heavy hitters that are printed out as part * of the statistics. - * + * * @return maximum number of heavy hitters to print */ public int getStatisticsMaxHeavyHitters() { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java index d5b48bc..43365dd 100644 --- a/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextUtil.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -52,7 +52,6 @@ import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.CompilerConfig; import org.apache.sysml.conf.CompilerConfig.ConfigType; import org.apache.sysml.conf.ConfigurationManager; @@ -78,6 +77,7 @@ import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.matrix.data.FrameBlock; import org.apache.sysml.runtime.matrix.data.MatrixBlock; import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.utils.MLContextProxy; import org.w3c.dom.Document; import org.w3c.dom.Node; import org.w3c.dom.NodeList; @@ -111,7 +111,7 @@ public final class MLContextUtil { /** * Compare two version strings (ie, "1.4.0" and "1.4.1"). - * + * * @param versionStr1 * First version string. * @param versionStr2 @@ -157,7 +157,7 @@ public final class MLContextUtil { /** * Determine whether the Spark version is supported. - * + * * @param sparkVersion * Spark version string (ie, "1.5.0"). * @param minimumRecommendedSparkVersion @@ -171,7 +171,7 @@ public final class MLContextUtil { /** * Check that the Spark version is supported. If it isn't supported, throw * an MLContextException. - * + * * @param spark * SparkSession * @throws MLContextException @@ -203,7 +203,7 @@ public final class MLContextUtil { /** * Obtain minimum recommended Spark version from the pom.xml file. - * + * * @return the minimum recommended Spark version from XML parsing of the pom file (during development). */ static String getMinimumRecommendedSparkVersionFromPom() { @@ -213,7 +213,7 @@ public final class MLContextUtil { /** * Obtain the text associated with an XML element from the pom.xml file. In this implementation, * the element should be uniquely named, or results will be unpredicable. - * + * * @param property unique property (element) from the pom.xml file * @return the text value associated with the given property */ @@ -249,7 +249,7 @@ public final class MLContextUtil { /** * Set SystemML configuration properties based on a configuration file. - * + * * @param configFilePath * Path to configuration file. * @throws MLContextException @@ -281,7 +281,7 @@ public final class MLContextUtil { /** * Verify that the types of input values are supported. - * + * * @param inputs * Map of String/Object pairs * @throws MLContextException @@ -295,7 +295,7 @@ public final class MLContextUtil { /** * Verify that the type of input value is supported. - * + * * @param name * The name of the input * @param value @@ -329,7 +329,7 @@ public final class MLContextUtil { /** * Verify that the type of input parameter value is supported. - * + * * @param parameterName * The name of the input parameter * @param parameterValue @@ -367,7 +367,7 @@ public final class MLContextUtil { /** * Is the object one of the supported basic data types? (Integer, Boolean, * Double, String) - * + * * @param object * the object type to be examined * @return {@code true} if type is a basic data type; otherwise @@ -387,7 +387,7 @@ public final class MLContextUtil { /** * Obtain the SystemML scalar value type string equivalent of an accepted * basic type (Integer, Boolean, Double, String) - * + * * @param object * the object type to be examined * @return a String representing the type as a SystemML scalar value type @@ -413,7 +413,7 @@ public final class MLContextUtil { /** * Is the object one of the supported complex data types? (JavaRDD, RDD, * DataFrame, BinaryBlockMatrix, Matrix, double[][], MatrixBlock, URL) - * + * * @param object * the object type to be examined * @return {@code true} if type is a complex data type; otherwise @@ -433,7 +433,7 @@ public final class MLContextUtil { /** * Converts non-string basic input parameter values to strings to pass to * the parser. - * + * * @param basicInputParameterMap * map of input parameters * @param scriptType @@ -475,7 +475,7 @@ public final class MLContextUtil { /** * Convert input types to internal SystemML representations - * + * * @param parameterName * The name of the input parameter * @param parameterValue @@ -488,7 +488,7 @@ public final class MLContextUtil { /** * Convert input types to internal SystemML representations - * + * * @param parameterName * The name of the input parameter * @param parameterValue @@ -626,7 +626,7 @@ public final class MLContextUtil { /** * If no metadata is supplied for an RDD or JavaRDD, this method can be used * to determine whether the data appears to be matrix (or a frame) - * + * * @param line * a line of the RDD * @return {@code true} if all the csv-separated values are numbers, @@ -651,7 +651,7 @@ public final class MLContextUtil { /** * Examine the DataFrame schema to determine whether the data appears to be * a matrix. - * + * * @param df * the DataFrame * @return {@code true} if the DataFrame appears to be a matrix, @@ -684,7 +684,7 @@ public final class MLContextUtil { /** * Return a double-quoted string with inner single and double quotes * escaped. - * + * * @param str * the original string * @return double-quoted string with inner single and double quotes escaped @@ -714,7 +714,7 @@ public final class MLContextUtil { /** * Display the keys and values in a Map - * + * * @param mapName * the name of the map * @param map @@ -745,7 +745,7 @@ public final class MLContextUtil { /** * Display the values in a Set - * + * * @param setName * the name of the Set * @param set @@ -773,7 +773,7 @@ public final class MLContextUtil { /** * Display the keys and values in the symbol table - * + * * @param name * the name of the symbol table * @param symbolTable @@ -790,7 +790,7 @@ public final class MLContextUtil { /** * Display the keys and values in the symbol table - * + * * @param symbolTable * the LocalVariableMap * @return the keys and values in the symbol table as a String @@ -823,7 +823,7 @@ public final class MLContextUtil { /** * Obtain a symbol table output type as a String - * + * * @param symbolTable * the symbol table * @param outputName @@ -850,7 +850,7 @@ public final class MLContextUtil { /** * Obtain a display of script inputs. - * + * * @param name * the title to display for the inputs * @param map @@ -896,9 +896,9 @@ public final class MLContextUtil { String str = null; if(object instanceof MatrixBlock) { MatrixBlock mb = (MatrixBlock) object; - str = "MatrixBlock [sparse? = " + mb.isInSparseFormat() + ", nonzeros = " + mb.getNonZeros() + ", size: " + mb.getNumRows() + " X " + mb.getNumColumns() + "]"; + str = "MatrixBlock [sparse? = " + mb.isInSparseFormat() + ", nonzeros = " + mb.getNonZeros() + ", size: " + mb.getNumRows() + " X " + mb.getNumColumns() + "]"; } - else + else str = object.toString(); // TODO: Deal with OOM for other objects such as Frame, etc str = StringUtils.abbreviate(str, 100); sb.append(str); @@ -910,7 +910,7 @@ public final class MLContextUtil { /** * Obtain a display of the script outputs. - * + * * @param name * the title to display for the outputs * @param outputNames @@ -918,7 +918,7 @@ public final class MLContextUtil { * @param symbolTable * the symbol table * @return the script outputs represented as a String - * + * */ public static String displayOutputs(String name, Set<String> outputNames, LocalVariableMap symbolTable) { StringBuilder sb = new StringBuilder(); @@ -930,13 +930,13 @@ public final class MLContextUtil { /** * Obtain a display of the script outputs. - * + * * @param outputNames * the names of the output variables * @param symbolTable * the symbol table * @return the script outputs represented as a String - * + * */ public static String displayOutputs(Set<String> outputNames, LocalVariableMap symbolTable) { StringBuilder sb = new StringBuilder(); @@ -970,7 +970,7 @@ public final class MLContextUtil { /** * The SystemML welcome message - * + * * @return the SystemML welcome message */ public static String welcomeMessage() { @@ -989,7 +989,7 @@ public final class MLContextUtil { /** * Generate a String history entry for a script. - * + * * @param script * the script * @param when @@ -1009,7 +1009,7 @@ public final class MLContextUtil { /** * Generate a String listing of the script execution history. - * + * * @param scriptHistory * the list of script history entries * @return the listing of the script execution history as a String @@ -1031,7 +1031,7 @@ public final class MLContextUtil { /** * Obtain the Spark Context - * + * * @param mlContext * the SystemML MLContext * @return the Spark Context @@ -1042,7 +1042,7 @@ public final class MLContextUtil { /** * Obtain the Java Spark Context - * + * * @param mlContext * the SystemML MLContext * @return the Java Spark Context @@ -1053,39 +1053,39 @@ public final class MLContextUtil { /** * Obtain the Spark Context from the MLContextProxy - * + * * @return the Spark Context */ public static SparkContext getSparkContextFromProxy() { - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); + MLContext activeMLContext = MLContextProxy.getActiveMLContextForAPI(); SparkContext sc = getSparkContext(activeMLContext); return sc; } /** * Obtain the Java Spark Context from the MLContextProxy - * + * * @return the Java Spark Context */ public static JavaSparkContext getJavaSparkContextFromProxy() { - MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContextForAPI(); + MLContext activeMLContext = MLContextProxy.getActiveMLContextForAPI(); JavaSparkContext jsc = getJavaSparkContext(activeMLContext); return jsc; } /** * Obtain the Spark Session from the MLContextProxy - * + * * @return the Spark Session */ public static SparkSession getSparkSessionFromProxy() { - return ((MLContext) MLContextProxy.getActiveMLContextForAPI()).getSparkSession(); + return MLContextProxy.getActiveMLContextForAPI().getSparkSession(); } /** * Determine if the symbol table contains a FrameObject with the given * variable name. - * + * * @param symbolTable * the LocalVariableMap * @param variableName @@ -1101,7 +1101,7 @@ public final class MLContextUtil { /** * Determine if the symbol table contains a MatrixObject with the given * variable name. - * + * * @param symbolTable * the LocalVariableMap * @param variableName @@ -1116,7 +1116,7 @@ public final class MLContextUtil { /** * Delete the 'remove variable' instructions from a runtime program. - * + * * @param progam * runtime program */ @@ -1139,7 +1139,7 @@ public final class MLContextUtil { /** * Recursively traverse program block to delete 'remove variable' * instructions. - * + * * @param pb * Program block */ @@ -1166,7 +1166,7 @@ public final class MLContextUtil { /** * Delete 'remove variable' instructions. - * + * * @param instructions * list of instructions */ http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/parser/StatementBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/parser/StatementBlock.java b/src/main/java/org/apache/sysml/parser/StatementBlock.java index 3947372..79c7a9a 100644 --- a/src/main/java/org/apache/sysml/parser/StatementBlock.java +++ b/src/main/java/org/apache/sysml/parser/StatementBlock.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -27,7 +27,6 @@ import java.util.List; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.HopsException; @@ -39,29 +38,30 @@ import org.apache.sysml.parser.Expression.ValueType; import org.apache.sysml.parser.LanguageException.LanguageErrorCodes; import org.apache.sysml.parser.PrintStatement.PRINTTYPE; import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence; +import org.apache.sysml.utils.MLContextProxy; public class StatementBlock extends LiveVariableAnalysis { - + protected static final Log LOG = LogFactory.getLog(StatementBlock.class.getName()); protected static IDSequence _seq = new IDSequence(); - - protected DMLProgram _dmlProg; + + protected DMLProgram _dmlProg; protected ArrayList<Statement> _statements; ArrayList<Hop> _hops = null; ArrayList<Lop> _lops = null; HashMap<String,ConstIdentifier> _constVarsIn; HashMap<String,ConstIdentifier> _constVarsOut; - + private ArrayList<String> _updateInPlaceVars = null; private boolean _requiresRecompile = false; - + public StatementBlock() { _dmlProg = null; _statements = new ArrayList<Statement>(); _read = new VariableSet(); - _updated = new VariableSet(); + _updated = new VariableSet(); _gen = new VariableSet(); _kill = new VariableSet(); _warnSet = new VariableSet(); @@ -70,41 +70,41 @@ public class StatementBlock extends LiveVariableAnalysis _constVarsOut = new HashMap<String,ConstIdentifier>(); _updateInPlaceVars = new ArrayList<String>(); } - + public void setDMLProg(DMLProgram dmlProg){ _dmlProg = dmlProg; } - + public DMLProgram getDMLProg(){ return _dmlProg; } - + public void addStatement(Statement s){ _statements.add(s); - + if (_statements.size() == 1){ this._filename = s.getFilename(); - this._beginLine = s.getBeginLine(); + this._beginLine = s.getBeginLine(); this._beginColumn = s.getBeginColumn(); } - + this._endLine = s.getEndLine(); this._endColumn = s.getEndColumn(); - + } public void addStatementBlock(StatementBlock s){ for (int i = 0; i < s.getNumStatements(); i++){ _statements.add(s.getStatement(i)); } - - this._beginLine = _statements.get(0).getBeginLine(); + + this._beginLine = _statements.get(0).getBeginLine(); this._beginColumn = _statements.get(0).getBeginColumn(); - + this._endLine = _statements.get(_statements.size() - 1).getEndLine(); this._endColumn = _statements.get(_statements.size() - 1).getEndColumn(); } - + public int getNumStatements(){ return _statements.size(); } @@ -112,12 +112,12 @@ public class StatementBlock extends LiveVariableAnalysis public Statement getStatement(int i){ return _statements.get(i); } - + public ArrayList<Statement> getStatements() { return _statements; } - + public void setStatements( ArrayList<Statement> s ) { _statements = s; @@ -140,30 +140,30 @@ public class StatementBlock extends LiveVariableAnalysis } public boolean mergeable(){ - for (Statement s : _statements){ + for (Statement s : _statements){ if (s.controlStatement()) return false; } return true; } - + public boolean isMergeableFunctionCallBlock(DMLProgram dmlProg) throws LanguageException{ - + // if (DMLScript.ENABLE_DEBUG_MODE && !DMLScript.ENABLE_DEBUG_OPTIMIZER) if (DMLScript.ENABLE_DEBUG_MODE) return false; - - // check whether targetIndex stmt block is for a mergable function call + + // check whether targetIndex stmt block is for a mergable function call Statement stmt = this.getStatement(0); - + // Check whether targetIndex block is: control stmt block or stmt block for un-mergable function call - if ( stmt instanceof WhileStatement || stmt instanceof IfStatement || stmt instanceof ForStatement + if ( stmt instanceof WhileStatement || stmt instanceof IfStatement || stmt instanceof ForStatement || stmt instanceof FunctionStatement || ( stmt instanceof PrintStatement && ((PrintStatement)stmt).getType() == PRINTTYPE.STOP )/*|| stmt instanceof ELStatement*/ ) { return false; } - + if (stmt instanceof AssignmentStatement || stmt instanceof MultiAssignmentStatement){ Expression sourceExpr = null; if (stmt instanceof AssignmentStatement) { @@ -178,7 +178,7 @@ public class StatementBlock extends LiveVariableAnalysis if ( (sourceExpr instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)sourceExpr).multipleReturns()) || (sourceExpr instanceof ParameterizedBuiltinFunctionExpression && ((ParameterizedBuiltinFunctionExpression)sourceExpr).multipleReturns())) return false; - + //function calls (only mergable if inlined dml-bodied function) if (sourceExpr instanceof FunctionCallIdentifier){ FunctionCallIdentifier fcall = (FunctionCallIdentifier) sourceExpr; @@ -190,13 +190,13 @@ public class StatementBlock extends LiveVariableAnalysis return false; } } - + // regular function block return true; } - + public boolean isRewritableFunctionCall(Statement stmt, DMLProgram dmlProg) throws LanguageException{ - + // for regular stmt, check if this is a function call stmt block if (stmt instanceof AssignmentStatement || stmt instanceof MultiAssignmentStatement){ Expression sourceExpr = null; @@ -204,53 +204,53 @@ public class StatementBlock extends LiveVariableAnalysis sourceExpr = ((AssignmentStatement)stmt).getSource(); else sourceExpr = ((MultiAssignmentStatement)stmt).getSource(); - + if (sourceExpr instanceof FunctionCallIdentifier){ FunctionCallIdentifier fcall = (FunctionCallIdentifier) sourceExpr; FunctionStatementBlock fblock = dmlProg.getFunctionStatementBlock(fcall.getNamespace(),fcall.getName()); if (fblock == null) { throw new LanguageException(sourceExpr.printErrorLocation() + "function " + fcall.getName() + " is undefined in namespace " + fcall.getNamespace()); } - + //check for unsupported target indexed identifiers (for consistent error handling) - if( stmt instanceof AssignmentStatement + if( stmt instanceof AssignmentStatement && ((AssignmentStatement)stmt).getTarget() instanceof IndexedIdentifier ) { return false; } - + //check if function can be inlined if( rIsInlineableFunction(fblock, dmlProg) ) { return true; } } } - + // regular statement return false; } - + private boolean rIsInlineableFunction( FunctionStatementBlock fblock, DMLProgram prog ) { boolean ret = true; - + //reject external functions and function bodies with multiple blocks if( fblock.getStatements().isEmpty() //empty blocks - || fblock.getStatement(0) instanceof ExternalFunctionStatement + || fblock.getStatement(0) instanceof ExternalFunctionStatement || ((FunctionStatement)fblock.getStatement(0)).getBody().size() > 1 ) { return false; } - + //reject control flow and non-inlinable functions - if(!fblock.getStatements().isEmpty() && !((FunctionStatement)fblock.getStatement(0)).getBody().isEmpty()) + if(!fblock.getStatements().isEmpty() && !((FunctionStatement)fblock.getStatement(0)).getBody().isEmpty()) { StatementBlock stmtBlock = ((FunctionStatement)fblock.getStatement(0)).getBody().get(0); - + //reject control flow blocks if (stmtBlock instanceof IfStatementBlock || stmtBlock instanceof WhileStatementBlock || stmtBlock instanceof ForStatementBlock) return false; - + //recursively check that functions are inlinable for( Statement s : stmtBlock.getStatements() ){ if( s instanceof AssignmentStatement && ((AssignmentStatement)s).getSource() instanceof FunctionCallIdentifier ) @@ -261,7 +261,7 @@ public class StatementBlock extends LiveVariableAnalysis ret &= rIsInlineableFunction(fblock2, prog); if( as.getSource().toString().contains(DataExpression.FORMAT_TYPE + "=" + DataExpression.FORMAT_TYPE_VALUE_CSV) && as.getSource().toString().contains("read")) return false; - + if( !ret ) return false; } if( s instanceof MultiAssignmentStatement && ((MultiAssignmentStatement)s).getSource() instanceof FunctionCallIdentifier ) @@ -273,39 +273,39 @@ public class StatementBlock extends LiveVariableAnalysis } } } - + return ret; } - + public static ArrayList<StatementBlock> mergeFunctionCalls(ArrayList<StatementBlock> body, DMLProgram dmlProg) throws LanguageException { for(int i = 0; i <body.size(); i++){ - + StatementBlock currBlock = body.get(i); - + // recurse to children function statement blocks if (currBlock instanceof WhileStatementBlock){ WhileStatement wstmt = (WhileStatement)((WhileStatementBlock)currBlock).getStatement(0); - wstmt.setBody(mergeFunctionCalls(wstmt.getBody(),dmlProg)); + wstmt.setBody(mergeFunctionCalls(wstmt.getBody(),dmlProg)); } - + else if (currBlock instanceof ForStatementBlock){ ForStatement fstmt = (ForStatement)((ForStatementBlock)currBlock).getStatement(0); - fstmt.setBody(mergeFunctionCalls(fstmt.getBody(),dmlProg)); + fstmt.setBody(mergeFunctionCalls(fstmt.getBody(),dmlProg)); } - + else if (currBlock instanceof IfStatementBlock){ IfStatement ifstmt = (IfStatement)((IfStatementBlock)currBlock).getStatement(0); - ifstmt.setIfBody(mergeFunctionCalls(ifstmt.getIfBody(),dmlProg)); + ifstmt.setIfBody(mergeFunctionCalls(ifstmt.getIfBody(),dmlProg)); ifstmt.setElseBody(mergeFunctionCalls(ifstmt.getElseBody(),dmlProg)); } - + else if (currBlock instanceof FunctionStatementBlock){ FunctionStatement functStmt = (FunctionStatement)((FunctionStatementBlock)currBlock).getStatement(0); - functStmt.setBody(mergeFunctionCalls(functStmt.getBody(),dmlProg)); + functStmt.setBody(mergeFunctionCalls(functStmt.getBody(),dmlProg)); } } - + ArrayList<StatementBlock> result = new ArrayList<StatementBlock>(); StatementBlock currentBlock = null; @@ -330,10 +330,10 @@ public class StatementBlock extends LiveVariableAnalysis if (currentBlock != null) { result.add(currentBlock); } - - return result; + + return result; } - + public String toString(){ StringBuilder sb = new StringBuilder(); sb.append("statements\n"); @@ -380,79 +380,79 @@ public class StatementBlock extends LiveVariableAnalysis if (currentBlock != null) { result.add(currentBlock); } - + return result; } - + public ArrayList<Statement> rewriteFunctionCallStatements (DMLProgram dmlProg, ArrayList<Statement> statements) throws LanguageException { - + ArrayList<Statement> newStatements = new ArrayList<Statement>(); for (Statement current : statements){ if (isRewritableFunctionCall(current, dmlProg)){ - + Expression sourceExpr = null; if (current instanceof AssignmentStatement) sourceExpr = ((AssignmentStatement)current).getSource(); else sourceExpr = ((MultiAssignmentStatement)current).getSource(); - + FunctionCallIdentifier fcall = (FunctionCallIdentifier) sourceExpr; FunctionStatementBlock fblock = dmlProg.getFunctionStatementBlock(fcall.getNamespace(), fcall.getName()); if (fblock == null){ fcall.raiseValidateError("function " + fcall.getName() + " is undefined in namespace " + fcall.getNamespace(), false); } FunctionStatement fstmt = (FunctionStatement)fblock.getStatement(0); - + // recursive inlining (no memo required because update-inplace of function statement blocks, so no redundant inlining) if( rIsInlineableFunction(fblock, dmlProg) ){ fstmt.getBody().get(0).setStatements(rewriteFunctionCallStatements(dmlProg, fstmt.getBody().get(0).getStatements())); } - + //MB: we cannot use the hash since multiple interleaved inlined functions should be independent. //String prefix = new Integer(fblock.hashCode()).toString() + "_"; String prefix = _seq.getNextID() + "_"; - + if (fstmt.getBody().size() > 1){ sourceExpr.raiseValidateError("rewritable function can only have 1 statement block", false); } StatementBlock sblock = fstmt.getBody().get(0); - + if( fcall.getParamExprs().size() < fstmt.getInputParams().size() ) { sourceExpr.raiseValidateError("Wrong number of function parameters: "+ fcall.getParamExprs().size() + ", but " + fstmt.getInputParams().size()+" expected."); } - + for (int i =0; i < fstmt.getInputParams().size(); i++) { DataIdentifier currFormalParam = fstmt.getInputParams().get(i); - + // create new assignment statement String newFormalParameterName = prefix + currFormalParam.getName(); DataIdentifier newTarget = new DataIdentifier(currFormalParam); newTarget.setName(newFormalParameterName); - + Expression currCallParam = fcall.getParamExprs().get(i).getExpr(); - + //auto casting of inputs on inlining (if required) ValueType targetVT = newTarget.getValueType(); - if( newTarget.getDataType()==DataType.SCALAR && currCallParam.getOutput() != null + if( newTarget.getDataType()==DataType.SCALAR && currCallParam.getOutput() != null && targetVT != currCallParam.getOutput().getValueType() && targetVT != ValueType.STRING ) { - currCallParam = new BuiltinFunctionExpression(BuiltinFunctionExpression.getValueTypeCastOperator(targetVT), new Expression[] {currCallParam}, - newTarget.getFilename(), newTarget.getBeginLine(), newTarget.getBeginColumn(), newTarget.getEndLine(), newTarget.getEndColumn()); + currCallParam = new BuiltinFunctionExpression(BuiltinFunctionExpression.getValueTypeCastOperator(targetVT), new Expression[] {currCallParam}, + newTarget.getFilename(), newTarget.getBeginLine(), newTarget.getBeginColumn(), newTarget.getEndLine(), newTarget.getEndColumn()); } - + // create the assignment statement to bind the call parameter to formal parameter AssignmentStatement binding = new AssignmentStatement(newTarget, currCallParam, newTarget.getBeginLine(), newTarget.getBeginColumn(), newTarget.getEndLine(), newTarget.getEndColumn()); newStatements.add(binding); } - + for (Statement stmt : sblock._statements){ - - // rewrite the statement to use the "rewritten" name + + // rewrite the statement to use the "rewritten" name Statement rewrittenStmt = stmt.rewriteStatement(prefix); - newStatements.add(rewrittenStmt); + newStatements.add(rewrittenStmt); } if (current instanceof AssignmentStatement) { @@ -473,14 +473,14 @@ public class StatementBlock extends LiveVariableAnalysis } // handle the return values for (int i = 0; i < fstmt.getOutputParams().size(); i++){ - + // get the target (return parameter from function) DataIdentifier currReturnParam = fstmt.getOutputParams().get(i); String newSourceName = prefix + currReturnParam.getName(); DataIdentifier newSource = new DataIdentifier(currReturnParam); newSource.setName(newSourceName); - - // get binding + + // get binding DataIdentifier newTarget = null; if (current instanceof AssignmentStatement){ if (i > 0) { @@ -500,90 +500,90 @@ public class StatementBlock extends LiveVariableAnalysis else{ newTarget = new DataIdentifier(((MultiAssignmentStatement)current).getTargetList().get(i)); } - - //auto casting of inputs on inlining (always, redundant cast removed during Hop Rewrites) + + //auto casting of inputs on inlining (always, redundant cast removed during Hop Rewrites) ValueType sourceVT = newSource.getValueType(); if( newSource.getDataType()==DataType.SCALAR && sourceVT != ValueType.STRING ){ newSource = new BuiltinFunctionExpression(BuiltinFunctionExpression.getValueTypeCastOperator(sourceVT), new Expression[] {newSource}, newTarget.getFilename(), newTarget.getBeginLine(), newTarget.getBeginColumn(), newTarget.getEndLine(), newTarget.getEndColumn()); } - + // create the assignment statement to bind the call parameter to formal parameter AssignmentStatement binding = new AssignmentStatement(newTarget, newSource, newTarget.getBeginLine(), newTarget.getBeginColumn(), newTarget.getEndLine(), newTarget.getEndColumn()); - + newStatements.add(binding); } - + } // end if (isRewritableFunctionCall(current, dmlProg) - + else { newStatements.add(current); } } - + return newStatements; } - - public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) - throws LanguageException, ParseException, IOException + + public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) + throws LanguageException, ParseException, IOException { _constVarsIn.putAll(constVars); HashMap<String, ConstIdentifier> currConstVars = new HashMap<String,ConstIdentifier>(); currConstVars.putAll(constVars); - + _statements = rewriteFunctionCallStatements(dmlProg, _statements); _dmlProg = dmlProg; - + for (Statement current : _statements){ - + if (current instanceof OutputStatement){ OutputStatement os = (OutputStatement)current; - + // validate variable being written by output statement exists DataIdentifier target = (DataIdentifier)os.getIdentifier(); if (ids.getVariable(target.getName()) == null) { //undefined variables are always treated unconditionally as error in order to prevent common script-level bugs raiseValidateError("Undefined Variable (" + target.getName() + ") used in statement", false, LanguageErrorCodes.INVALID_PARAMETERS); } - + if ( ids.getVariable(target.getName()).getDataType() == DataType.SCALAR) { boolean paramsOkay = true; for (String key : os.getSource().getVarParams().keySet()){ - if (! (key.equals(DataExpression.IO_FILENAME) || key.equals(DataExpression.FORMAT_TYPE))) + if (! (key.equals(DataExpression.IO_FILENAME) || key.equals(DataExpression.FORMAT_TYPE))) paramsOkay = false; } if( !paramsOkay ) { raiseValidateError("Invalid parameters in write statement: " + os.toString(), conditional); } } - + Expression source = os.getSource(); source.setOutput(target); source.validateExpression(ids.getVariables(), currConstVars, conditional); - + setStatementFormatType(os, conditional); target.setDimensionValueProperties(ids.getVariable(target.getName())); } - + else if (current instanceof AssignmentStatement){ AssignmentStatement as = (AssignmentStatement)current; - DataIdentifier target = as.getTarget(); + DataIdentifier target = as.getTarget(); Expression source = as.getSource(); - + if (source instanceof FunctionCallIdentifier) { ((FunctionCallIdentifier) source).validateExpression(dmlProg, ids.getVariables(),currConstVars, conditional); } else { if( MLContextProxy.isActive() ) MLContextProxy.setAppropriateVarsForRead(source, target._name); - + source.validateExpression(ids.getVariables(), currConstVars, conditional); } - + if (source instanceof DataExpression && ((DataExpression)source).getOpCode() == Expression.DataOp.READ) setStatementFormatType(as, conditional); - - // Handle const vars: (a) basic constant propagation, and (b) transitive constant propagation over assignments + + // Handle const vars: (a) basic constant propagation, and (b) transitive constant propagation over assignments if (target != null) { currConstVars.remove(target.getName()); if(source instanceof ConstIdentifier && !(target instanceof IndexedIdentifier)){ //basic @@ -596,10 +596,10 @@ public class StatementBlock extends LiveVariableAnalysis } } } - + if (source instanceof BuiltinFunctionExpression){ BuiltinFunctionExpression bife = (BuiltinFunctionExpression)source; - if ( bife.getOpCode() == Expression.BuiltinFunctionOp.NROW + if ( bife.getOpCode() == Expression.BuiltinFunctionOp.NROW || bife.getOpCode() == Expression.BuiltinFunctionOp.NCOL ) { DataIdentifier id = (DataIdentifier)bife.getFirstExpr(); @@ -610,15 +610,15 @@ public class StatementBlock extends LiveVariableAnalysis } IntIdentifier intid = null; if (bife.getOpCode() == Expression.BuiltinFunctionOp.NROW){ - intid = new IntIdentifier((currVal instanceof IndexedIdentifier)?((IndexedIdentifier)currVal).getOrigDim1():currVal.getDim1(), + intid = new IntIdentifier((currVal instanceof IndexedIdentifier)?((IndexedIdentifier)currVal).getOrigDim1():currVal.getDim1(), bife.getFilename(), bife.getBeginLine(), bife.getBeginColumn(), bife.getEndLine(), bife.getEndColumn()); } else { - intid = new IntIdentifier((currVal instanceof IndexedIdentifier)?((IndexedIdentifier)currVal).getOrigDim2():currVal.getDim2(), + intid = new IntIdentifier((currVal instanceof IndexedIdentifier)?((IndexedIdentifier)currVal).getOrigDim2():currVal.getDim2(), bife.getFilename(), bife.getBeginLine(), bife.getBeginColumn(), bife.getEndLine(), bife.getEndColumn()); } - - // handle case when nrow / ncol called on variable with size unknown (dims == -1) - // --> const prop NOT possible + + // handle case when nrow / ncol called on variable with size unknown (dims == -1) + // --> const prop NOT possible if (intid.getValue() != -1){ currConstVars.put(target.getName(), intid); } @@ -633,10 +633,10 @@ public class StatementBlock extends LiveVariableAnalysis if (source.getOutput() instanceof IndexedIdentifier){ target.setDimensions(source.getOutput().getDim1(), source.getOutput().getDim2()); } - + } // CASE: target is indexed identifier - else + else { // process the "target" being indexed DataIdentifier targetAsSeen = ids.getVariable(target.getName()); @@ -644,7 +644,7 @@ public class StatementBlock extends LiveVariableAnalysis target.raiseValidateError("cannot assign value to indexed identifier " + target.toString() + " without first initializing " + target.getName(), conditional); } target.setProperties(targetAsSeen); - + // process the expressions for the indexing if ( ((IndexedIdentifier)target).getRowLowerBound() != null ) ((IndexedIdentifier)target).getRowLowerBound().validateExpression(ids.getVariables(), currConstVars, conditional); @@ -654,71 +654,71 @@ public class StatementBlock extends LiveVariableAnalysis ((IndexedIdentifier)target).getColLowerBound().validateExpression(ids.getVariables(), currConstVars, conditional); if ( ((IndexedIdentifier)target).getColUpperBound() != null ) ((IndexedIdentifier)target).getColUpperBound().validateExpression(ids.getVariables(), currConstVars, conditional); - + // validate that LHS indexed identifier is being assigned a matrix value // if (source.getOutput().getDataType() != Expression.DataType.MATRIX){ // LOG.error(target.printErrorLocation() + "Indexed expression " + target.toString() + " can only be assigned matrix value"); // throw new LanguageException(target.printErrorLocation() + "Indexed expression " + target.toString() + " can only be assigned matrix value"); // } - + // validate that size of LHS index ranges is being assigned: // (a) a matrix value of same size as LHS // (b) singleton value (semantics: initialize enitre submatrix with this value) IndexPair targetSize = ((IndexedIdentifier)target).calculateIndexedDimensions(ids.getVariables(), currConstVars, conditional); - + if (targetSize._row >= 1 && source.getOutput().getDim1() > 1 && targetSize._row != source.getOutput().getDim1()){ - target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions " - + targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions " + target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions " + + targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions " + source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional); } - + if (targetSize._col >= 1 && source.getOutput().getDim2() > 1 && targetSize._col != source.getOutput().getDim2()){ - target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions " - + targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions " + target.raiseValidateError("Dimension mismatch. Indexed expression " + target.toString() + " can only be assigned matrix with dimensions " + + targetSize._row + " rows and " + targetSize._col + " cols. Attempted to assign matrix with dimensions " + source.getOutput().getDim1() + " rows and " + source.getOutput().getDim2() + " cols ", conditional); } - - ((IndexedIdentifier)target).setDimensions(targetSize._row, targetSize._col); + + ((IndexedIdentifier)target).setDimensions(targetSize._row, targetSize._col); } - + if (target != null) { ids.addVariable(target.getName(), target); } - + } - + else if (current instanceof MultiAssignmentStatement){ MultiAssignmentStatement mas = (MultiAssignmentStatement) current; - ArrayList<DataIdentifier> targetList = mas.getTargetList(); - + ArrayList<DataIdentifier> targetList = mas.getTargetList(); + // perform validation of source expression Expression source = mas.getSource(); /* - * MultiAssignmentStatments currently supports only External, + * MultiAssignmentStatments currently supports only External, * User-defined, and Multi-return Builtin function expressions */ - if (!(source instanceof DataIdentifier) + if (!(source instanceof DataIdentifier) || (source instanceof DataIdentifier && !((DataIdentifier)source).multipleReturns()) ) { //if (!(source instanceof FunctionCallIdentifier) ) { //|| !(source instanceof BuiltinFunctionExpression && ((BuiltinFunctionExpression)source).isMultiReturnBuiltinFunction()) ){ source.raiseValidateError("can only use user-defined functions with multi-assignment statement", conditional); } - + if ( source instanceof FunctionCallIdentifier) { FunctionCallIdentifier fci = (FunctionCallIdentifier)source; fci.validateExpression(dmlProg, ids.getVariables(), currConstVars, conditional); } - else if ( (source instanceof BuiltinFunctionExpression || source instanceof ParameterizedBuiltinFunctionExpression) + else if ( (source instanceof BuiltinFunctionExpression || source instanceof ParameterizedBuiltinFunctionExpression) && ((DataIdentifier)source).multipleReturns()) { source.validateExpression(mas, ids.getVariables(), currConstVars, conditional); } - else + else throw new LanguageException("Unexpected error."); - - + + if ( source instanceof FunctionCallIdentifier ) { for (int j =0; j< targetList.size(); j++){ - + DataIdentifier target = targetList.get(j); // set target properties (based on type info in function call statement return params) FunctionCallIdentifier fci = (FunctionCallIdentifier)source; @@ -746,7 +746,7 @@ public class StatementBlock extends LiveVariableAnalysis } } } - + else if(current instanceof ForStatement || current instanceof IfStatement || current instanceof WhileStatement ){ raiseValidateError("control statement (WhileStatement, IfStatement, ForStatement) should not be in generic statement block. Likely a parsing error", conditional); } @@ -762,31 +762,31 @@ public class StatementBlock extends LiveVariableAnalysis } } - + // no work to perform for PathStatement or ImportStatement else if (current instanceof PathStatement){} else if (current instanceof ImportStatement){} - - + + else { raiseValidateError("cannot process statement of type " + current.getClass().getSimpleName(), conditional); } - + } // end for (Statement current : _statements){ _constVarsOut.putAll(currConstVars); return ids; } - - public void setStatementFormatType(OutputStatement s, boolean conditionalValidate) + + public void setStatementFormatType(OutputStatement s, boolean conditionalValidate) throws LanguageException, ParseException { //case of specified format parameter if (s.getExprParam(DataExpression.FORMAT_TYPE)!= null ) - { - Expression formatTypeExpr = s.getExprParam(DataExpression.FORMAT_TYPE); + { + Expression formatTypeExpr = s.getExprParam(DataExpression.FORMAT_TYPE); if (!(formatTypeExpr instanceof StringIdentifier)){ - raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + " can only be a string with one of following values: binary, text, mm, csv.", false, LanguageErrorCodes.INVALID_PARAMETERS); } String ft = formatTypeExpr.toString(); @@ -798,33 +798,33 @@ public class StatementBlock extends LiveVariableAnalysis s.getIdentifier().setFormatType(FormatType.MM); } else if (ft.equalsIgnoreCase(DataExpression.FORMAT_TYPE_VALUE_CSV)){ s.getIdentifier().setFormatType(FormatType.CSV); - } else{ - raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + } else{ + raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + " can only be a string with one of following values: binary, text, mm, csv; invalid format: '"+ft+"'.", false, LanguageErrorCodes.INVALID_PARAMETERS); } - } + } //case of unspecified format parameter, use default - else + else { - s.addExprParam(DataExpression.FORMAT_TYPE, new StringIdentifier(FormatType.TEXT.toString(), + s.addExprParam(DataExpression.FORMAT_TYPE, new StringIdentifier(FormatType.TEXT.toString(), s.getFilename(), s.getBeginLine(), s.getBeginColumn(), s.getEndLine(), s.getEndColumn()), true); s.getIdentifier().setFormatType(FormatType.TEXT); } } - - public void setStatementFormatType(AssignmentStatement s, boolean conditionalValidate) + + public void setStatementFormatType(AssignmentStatement s, boolean conditionalValidate) throws LanguageException, ParseException { - + if (!(s.getSource() instanceof DataExpression)) return; DataExpression dataExpr = (DataExpression)s.getSource(); - + if (dataExpr.getVarParam(DataExpression.FORMAT_TYPE)!= null ){ - - Expression formatTypeExpr = dataExpr.getVarParam(DataExpression.FORMAT_TYPE); + + Expression formatTypeExpr = dataExpr.getVarParam(DataExpression.FORMAT_TYPE); if (!(formatTypeExpr instanceof StringIdentifier)){ - raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + " can only be a string with one of following values: binary, text", conditionalValidate, LanguageErrorCodes.INVALID_PARAMETERS); } String ft = formatTypeExpr.toString(); @@ -836,8 +836,8 @@ public class StatementBlock extends LiveVariableAnalysis s.getTarget().setFormatType(FormatType.MM); } else if (ft.equalsIgnoreCase(DataExpression.FORMAT_TYPE_VALUE_CSV)){ s.getTarget().setFormatType(FormatType.CSV); - } else{ - raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + } else{ + raiseValidateError("IO statement parameter " + DataExpression.FORMAT_TYPE + " can only be a string with one of following values: binary, text, mm, csv", conditionalValidate, LanguageErrorCodes.INVALID_PARAMETERS); } } else { @@ -847,33 +847,33 @@ public class StatementBlock extends LiveVariableAnalysis } } - + /** * For each statement: - * + * * gen rule: for each variable read in current statement but not updated in any PRIOR statement, add to gen * Handles case where variable both read and updated in same statement (i = i + 1, i needs to be added to gen) - * + * * kill rule: for each variable updated in current statement but not read in this or any PRIOR statement, - * add to kill. - * + * add to kill. + * */ @Override public VariableSet initializeforwardLV(VariableSet activeIn) throws LanguageException { - + for (Statement s : _statements){ s.initializeforwardLV(activeIn); VariableSet read = s.variablesRead(); VariableSet updated = s.variablesUpdated(); - + if (s instanceof WhileStatement || s instanceof IfStatement || s instanceof ForStatement){ raiseValidateError("control statement (while / for / if) cannot be in generic statement block", false); } - + if (read != null){ - // for each variable read in this statement but not updated in + // for each variable read in this statement but not updated in // any prior statement, add to sb._gen - + for (String var : read.getVariableNames()) { if (!_updated.containsVariable(var)) { _gen.addVariable(var, read.getVariable(var)); @@ -886,13 +886,13 @@ public class StatementBlock extends LiveVariableAnalysis if (updated != null) { // for each updated variable that is not read - for (String var : updated.getVariableNames()) + for (String var : updated.getVariableNames()) { //NOTE MB: always add updated vars to kill (in order to prevent side effects //of implicitly updated statistics over common data identifiers, propagated from //downstream operators to its inputs due to 'livein = gen \cup (liveout-kill))'. _kill.addVariable(var, _updated.getVariable(var)); - + //if (!_read.containsVariable(var)) { // _kill.addVariable(var, _updated.getVariable(var)); //} @@ -904,9 +904,9 @@ public class StatementBlock extends LiveVariableAnalysis _liveOut.addVariables(_updated); return _liveOut; } - + @Override - public VariableSet initializebackwardLV(VariableSet loPassed) + public VariableSet initializebackwardLV(VariableSet loPassed) throws LanguageException { int numStatements = _statements.size(); @@ -914,74 +914,74 @@ public class StatementBlock extends LiveVariableAnalysis for (int i = numStatements-1; i>=0; i--){ lo = _statements.get(i).initializebackwardLV(lo); } - + return new VariableSet(lo); } public HashMap<String, ConstIdentifier> getConstIn(){ return _constVarsIn; } - + public HashMap<String, ConstIdentifier> getConstOut(){ return _constVarsOut; } - - - public VariableSet analyze(VariableSet loPassed) + + + public VariableSet analyze(VariableSet loPassed) throws LanguageException{ - + VariableSet candidateLO = new VariableSet(); candidateLO.addVariables(loPassed); //candidateLO.addVariables(_gen); - + VariableSet origLiveOut = new VariableSet(); origLiveOut.addVariables(_liveOut); - + _liveOut = new VariableSet(); for (String name : candidateLO.getVariableNames()){ if (origLiveOut.containsVariable(name)){ _liveOut.addVariable(name, candidateLO.getVariable(name)); } } - + initializebackwardLV(_liveOut); - + _liveIn = new VariableSet(); _liveIn.addVariables(_liveOut); _liveIn.removeVariables(_kill); _liveIn.addVariables(_gen); - + VariableSet liveInReturn = new VariableSet(); liveInReturn.addVariables(_liveIn); return liveInReturn; } - + /////////////////////////////////////////////////////////////// // validate error handling (consistent for all expressions) - - public void raiseValidateError( String msg, boolean conditional ) + + public void raiseValidateError( String msg, boolean conditional ) throws LanguageException { raiseValidateError(msg, conditional, null); } - - public void raiseValidateError( String msg, boolean conditional, String errorCode ) + + public void raiseValidateError( String msg, boolean conditional, String errorCode ) throws LanguageException { if( conditional ) //warning if conditional { String fullMsg = this.printWarningLocation() + msg; - + LOG.warn( fullMsg ); } else //error and exception if unconditional { String fullMsg = this.printErrorLocation() + msg; - - //LOG.error( fullMsg ); //no redundant error + + //LOG.error( fullMsg ); //no redundant error if( errorCode != null ) throw new LanguageException( fullMsg, errorCode ); - else + else throw new LanguageException( fullMsg ); } } @@ -992,17 +992,17 @@ public class StatementBlock extends LiveVariableAnalysis private String _filename = "MAIN SCRIPT"; private int _beginLine = 0, _beginColumn = 0; private int _endLine = 0, _endColumn = 0; - + public void setFilename (String fname) { _filename = fname; } public void setBeginLine(int passed) { _beginLine = passed; } public void setBeginColumn(int passed) { _beginColumn = passed; } public void setEndLine(int passed) { _endLine = passed; } public void setEndColumn(int passed) { _endColumn = passed; } - + public void setAllPositions(String fname, int blp, int bcp, int elp, int ecp){ _filename = fname; - _beginLine = blp; - _beginColumn = bcp; + _beginLine = blp; + _beginColumn = bcp; _endLine = elp; _endColumn = ecp; } @@ -1012,39 +1012,39 @@ public class StatementBlock extends LiveVariableAnalysis public int getBeginColumn() { return _beginColumn; } public int getEndLine() { return _endLine; } public int getEndColumn() { return _endColumn; } - + public String printErrorLocation(){ return "ERROR: " + _filename + " -- line " + _beginLine + ", column " + _beginColumn + " -- "; } - + public String printBlockErrorLocation(){ return "ERROR: " + _filename + " -- statement block between lines " + _beginLine + " and " + _endLine + " -- "; } - + public String printWarningLocation(){ return "WARNING: " + _filename + " -- line " + _beginLine + ", column " + _beginColumn + " -- "; } - + ///////// // materialized hops recompilation / updateinplace flags //// - + public void updateRecompilationFlag() throws HopsException { - _requiresRecompile = ConfigurationManager.isDynamicRecompilation() + _requiresRecompile = ConfigurationManager.isDynamicRecompilation() && Recompiler.requiresRecompilation(get_hops()); } - + public boolean requiresRecompilation() { return _requiresRecompile; } - + public ArrayList<String> getUpdateInPlaceVars() { return _updateInPlaceVars; } - + public void setUpdateInPlaceVars( ArrayList<String> vars ) { _updateInPlaceVars = vars; } - + } // end class http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java index 71fb1c2..5987d51 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -24,14 +24,13 @@ import java.util.ArrayList; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.MLContextProxy; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.Hop; import org.apache.sysml.hops.OptimizerUtils; import org.apache.sysml.hops.recompile.Recompiler; import org.apache.sysml.lops.Lop; -import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.parser.StatementBlock; import org.apache.sysml.runtime.DMLRuntimeException; import org.apache.sysml.runtime.DMLScriptException; import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; @@ -47,29 +46,30 @@ import org.apache.sysml.runtime.instructions.cp.ScalarObject; import org.apache.sysml.runtime.instructions.cp.StringObject; import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.utils.MLContextProxy; import org.apache.sysml.utils.Statistics; import org.apache.sysml.yarn.DMLAppMasterUtils; -public class ProgramBlock -{ +public class ProgramBlock +{ protected static final Log LOG = LogFactory.getLog(ProgramBlock.class.getName()); private static final boolean CHECK_MATRIX_SPARSITY = false; - + protected Program _prog; // pointer to Program this ProgramBlock is part of protected ArrayList<Instruction> _inst; - + //additional attributes for recompile protected StatementBlock _sb = null; protected long _tid = 0; //by default _t0 - - - public ProgramBlock(Program prog) { + + + public ProgramBlock(Program prog) { _prog = prog; _inst = new ArrayList<Instruction>(); } - - + + //////////////////////////////////////////////// // getters, setters and similar functionality //////////////////////////////////////////////// @@ -77,15 +77,15 @@ public class ProgramBlock public Program getProgram(){ return _prog; } - + public void setProgram(Program prog){ _prog = prog; } - + public StatementBlock getStatementBlock(){ return _sb; } - + public void setStatementBlock( StatementBlock sb ){ _sb = sb; } @@ -97,23 +97,23 @@ public class ProgramBlock public Instruction getInstruction(int i) { return _inst.get(i); } - + public void setInstructions( ArrayList<Instruction> inst ) { _inst = inst; } - + public void addInstruction(Instruction inst) { _inst.add(inst); } - + public void addInstructions(ArrayList<Instruction> inst) { _inst.addAll(inst); } - + public int getNumInstructions() { return _inst.size(); } - + public void setThreadID( long id ){ _tid = id; } @@ -125,29 +125,29 @@ public class ProgramBlock /** * Executes this program block (incl recompilation if required). - * + * * @param ec execution context * @throws DMLRuntimeException if DMLRuntimeException occurs */ - public void execute(ExecutionContext ec) - throws DMLRuntimeException + public void execute(ExecutionContext ec) + throws DMLRuntimeException { ArrayList<Instruction> tmp = _inst; //dynamically recompile instructions if enabled and required - try + try { if( DMLScript.isActiveAM() ) //set program block specific remote memory DMLAppMasterUtils.setupProgramBlockRemoteMaxMemory(this); - + long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; - if( ConfigurationManager.isDynamicRecompilation() - && _sb != null + if( ConfigurationManager.isDynamicRecompilation() + && _sb != null && _sb.requiresRecompilation() ) { tmp = Recompiler.recompileHopsDag( _sb, _sb.get_hops(), ec.getVariables(), null, false, true, _tid); - + if( MLContextProxy.isActive() ) tmp = MLContextProxy.performCleanupAfterRecompilation(tmp); } @@ -162,14 +162,14 @@ public class ProgramBlock { throw new DMLRuntimeException("Unable to recompile program block.", ex); } - + //actual instruction execution executeInstructions(tmp, ec); } - + /** * Executes given predicate instructions (incl recompilation if required) - * + * * @param inst list of instructions * @param hops high-level operator * @param requiresRecompile true if requires recompile @@ -178,15 +178,15 @@ public class ProgramBlock * @return scalar object * @throws DMLRuntimeException if DMLRuntimeException occurs */ - public ScalarObject executePredicate(ArrayList<Instruction> inst, Hop hops, boolean requiresRecompile, ValueType retType, ExecutionContext ec) + public ScalarObject executePredicate(ArrayList<Instruction> inst, Hop hops, boolean requiresRecompile, ValueType retType, ExecutionContext ec) throws DMLRuntimeException { ArrayList<Instruction> tmp = inst; - + //dynamically recompile instructions if enabled and required try { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; - if( ConfigurationManager.isDynamicRecompilation() + if( ConfigurationManager.isDynamicRecompilation() && requiresRecompile ) { tmp = Recompiler.recompileHopsDag( @@ -203,27 +203,27 @@ public class ProgramBlock { throw new DMLRuntimeException("Unable to recompile predicate instructions.", ex); } - + //actual instruction execution return executePredicateInstructions(tmp, retType, ec); } - protected void executeInstructions(ArrayList<Instruction> inst, ExecutionContext ec) - throws DMLRuntimeException + protected void executeInstructions(ArrayList<Instruction> inst, ExecutionContext ec) + throws DMLRuntimeException { - for (int i = 0; i < inst.size(); i++) + for (int i = 0; i < inst.size(); i++) { //indexed access required due to dynamic add Instruction currInst = inst.get(i); - + //execute instruction ec.updateDebugState(i); executeSingleInstruction(currInst, ec); } } - protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst, ValueType retType, ExecutionContext ec) - throws DMLRuntimeException + protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst, ValueType retType, ExecutionContext ec) + throws DMLRuntimeException { ScalarObject ret = null; String retName = null; @@ -232,24 +232,24 @@ public class ProgramBlock for (int i = 0; i < inst.size(); i++) { //indexed access required due to debug mode - Instruction currInst = inst.get(i); + Instruction currInst = inst.get(i); if( !isRemoveVariableInstruction(currInst) ) { //execute instruction ec.updateDebugState(i); executeSingleInstruction(currInst, ec); - + //get last return name if(currInst instanceof ComputationCPInstruction ) - retName = ((ComputationCPInstruction) currInst).getOutputVariableName(); + retName = ((ComputationCPInstruction) currInst).getOutputVariableName(); else if(currInst instanceof VariableCPInstruction && ((VariableCPInstruction)currInst).getOutputVariableName()!=null) retName = ((VariableCPInstruction)currInst).getOutputVariableName(); } } - + //get return value TODO: how do we differentiate literals and variables? ret = (ScalarObject) ec.getScalarInput(retName, retType, false); - + //execute rmvar instructions for (int i = 0; i < inst.size(); i++) { //indexed access required due to debug mode @@ -259,7 +259,7 @@ public class ProgramBlock executeSingleInstruction(currInst, ec); } } - + //check and correct scalar ret type (incl save double to int) if( ret.getValueType() != retType ) switch( retType ) { @@ -270,51 +270,51 @@ public class ProgramBlock default: //do nothing } - + return ret; } - private void executeSingleInstruction( Instruction currInst, ExecutionContext ec ) + private void executeSingleInstruction( Instruction currInst, ExecutionContext ec ) throws DMLRuntimeException - { - try - { + { + try + { // start time measurement for statistics - long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? + long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? System.nanoTime() : 0; - + // pre-process instruction (debug state, inst patching, listeners) Instruction tmp = currInst.preprocessInstruction( ec ); - + // process actual instruction tmp.processInstruction( ec ); - - // post-process instruction (debug) + + // post-process instruction (debug) tmp.postprocessInstruction( ec ); - + // maintain aggregate statistics if( DMLScript.STATISTICS) { Statistics.maintainCPHeavyHitters( tmp.getExtendedOpcode(), System.nanoTime()-t0); } - + // optional trace information (instruction and runtime) if( LOG.isTraceEnabled() ) { long t1 = System.nanoTime(); String time = String.format("%.3f",((double)t1-t0)/1000000000); LOG.trace("Instruction: "+ tmp + " (executed in " + time + "s)."); } - - // optional check for correct nnz and sparse/dense representation of all + + // optional check for correct nnz and sparse/dense representation of all // variables in symbol table (for tracking source of wrong representation) if( CHECK_MATRIX_SPARSITY ) { checkSparsity( tmp, ec.getVariables() ); - } + } } catch (Exception e) { if (!DMLScript.ENABLE_DEBUG_MODE) { - if ( e instanceof DMLScriptException) + if ( e instanceof DMLScriptException) throw (DMLScriptException)e; else throw new DMLRuntimeException(this.printBlockErrorLocation() + "Error evaluating instruction: " + currInst.toString() , e); @@ -325,12 +325,12 @@ public class ProgramBlock } } - protected UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec, long tid) + protected UpdateType[] prepareUpdateInPlaceVariables(ExecutionContext ec, long tid) throws DMLRuntimeException { if( _sb == null || _sb.getUpdateInPlaceVars().isEmpty() ) return null; - + ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); UpdateType[] flags = new UpdateType[varnames.size()]; for( int i=0; i<flags.length; i++ ) @@ -341,27 +341,27 @@ public class ProgramBlock //create deep copy if required and if it fits in thread-local mem budget if( flags[i]==UpdateType.COPY && OptimizerUtils.getLocalMemBudget()/2 > OptimizerUtils.estimateSizeExactSparsity(mo.getMatrixCharacteristics())) { - MatrixObject moNew = new MatrixObject(mo); - MatrixBlock mbVar = mo.acquireRead(); - moNew.acquireModify( !mbVar.isInSparseFormat() ? new MatrixBlock(mbVar) : + MatrixObject moNew = new MatrixObject(mo); + MatrixBlock mbVar = mo.acquireRead(); + moNew.acquireModify( !mbVar.isInSparseFormat() ? new MatrixBlock(mbVar) : new MatrixBlock(mbVar, MatrixBlock.DEFAULT_INPLACE_SPARSEBLOCK, true) ); moNew.setFileName(mo.getFileName()+Lop.UPDATE_INPLACE_PREFIX+tid); mo.release(); - moNew.release(); + moNew.release(); moNew.setUpdateType(UpdateType.INPLACE); ec.setVariable(varname, moNew); } } - + return flags; } - protected void resetUpdateInPlaceVariableFlags(ExecutionContext ec, UpdateType[] flags) + protected void resetUpdateInPlaceVariableFlags(ExecutionContext ec, UpdateType[] flags) throws DMLRuntimeException { if( flags == null ) return; - + //reset update-in-place flag to pre-loop status ArrayList<String> varnames = _sb.getUpdateInPlaceVars(); for( int i=0; i<varnames.size(); i++ ) @@ -387,7 +387,7 @@ public class ProgramBlock MatrixObject mo = (MatrixObject)dat; if( mo.isDirty() && !mo.isPartitioned() ) { - MatrixBlock mb = mo.acquireRead(); + MatrixBlock mb = mo.acquireRead(); boolean sparse1 = mb.isInSparseFormat(); long nnz1 = mb.getNonZeros(); synchronized( mb ) { //potential state change @@ -397,33 +397,33 @@ public class ProgramBlock boolean sparse2 = mb.isInSparseFormat(); long nnz2 = mb.getNonZeros(); mo.release(); - + if( nnz1 != nnz2 ) throw new DMLRuntimeException("Matrix nnz meta data was incorrect: ("+varname+", actual="+nnz1+", expected="+nnz2+", inst="+lastInst+")"); - + if( sparse1 != sparse2 ) - throw new DMLRuntimeException("Matrix was in wrong data representation: ("+varname+", actual="+sparse1+", expected="+sparse2 + + throw new DMLRuntimeException("Matrix was in wrong data representation: ("+varname+", actual="+sparse1+", expected="+sparse2 + ", nrow="+mb.getNumRows()+", ncol="+mb.getNumColumns()+", nnz="+nnz1+", inst="+lastInst+")"); } } } } - + /////////////////////////////////////////////////////////////////////////// // store position information for program blocks /////////////////////////////////////////////////////////////////////////// - + public int _beginLine, _beginColumn; public int _endLine, _endColumn; - + public void setBeginLine(int passed) { _beginLine = passed; } public void setBeginColumn(int passed) { _beginColumn = passed; } public void setEndLine(int passed) { _endLine = passed; } public void setEndColumn(int passed) { _endColumn = passed; } - + public void setAllPositions(int blp, int bcp, int elp, int ecp){ - _beginLine = blp; - _beginColumn = bcp; + _beginLine = blp; + _beginColumn = bcp; _endLine = elp; _endColumn = ecp; } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java index 06a2005..c67cdd8 100644 --- a/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java +++ b/src/main/java/org/apache/sysml/runtime/controlprogram/context/SparkExecutionContext.java @@ -38,7 +38,7 @@ import org.apache.spark.storage.RDDInfo; import org.apache.spark.storage.StorageLevel; import org.apache.spark.util.LongAccumulator; import org.apache.sysml.api.DMLScript; -import org.apache.sysml.api.MLContextProxy; +import org.apache.sysml.api.mlcontext.MLContext; import org.apache.sysml.api.mlcontext.MLContextUtil; import org.apache.sysml.conf.ConfigurationManager; import org.apache.sysml.hops.OptimizerUtils; @@ -75,6 +75,7 @@ import org.apache.sysml.runtime.matrix.data.SparseBlock; import org.apache.sysml.runtime.matrix.mapred.MRJobConfiguration; import org.apache.sysml.runtime.util.MapReduceTool; import org.apache.sysml.runtime.util.UtilFunctions; +import org.apache.sysml.utils.MLContextProxy; import org.apache.sysml.utils.Statistics; import scala.Tuple2; @@ -189,13 +190,12 @@ public class SparkExecutionContext extends ExecutionContext //create a default spark context (master, appname, etc refer to system properties //as given in the spark configuration or during spark-submit) - Object mlCtxObj = MLContextProxy.getActiveMLContext(); + MLContext mlCtxObj = MLContextProxy.getActiveMLContext(); if(mlCtxObj != null) { // This is when DML is called through spark shell // Will clean the passing of static variables later as this involves minimal change to DMLScript - org.apache.sysml.api.mlcontext.MLContext mlCtx = (org.apache.sysml.api.mlcontext.MLContext) mlCtxObj; - _spctx = MLContextUtil.getJavaSparkContext(mlCtx); + _spctx = MLContextUtil.getJavaSparkContext(mlCtxObj); } else { http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/f3b1aebc/src/main/java/org/apache/sysml/utils/MLContextProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/utils/MLContextProxy.java b/src/main/java/org/apache/sysml/utils/MLContextProxy.java new file mode 100644 index 0000000..825d42a --- /dev/null +++ b/src/main/java/org/apache/sysml/utils/MLContextProxy.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.utils; + +import java.util.ArrayList; + +import org.apache.sysml.api.mlcontext.MLContext; +import org.apache.sysml.api.mlcontext.MLContextException; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.LanguageException; +import org.apache.sysml.runtime.instructions.Instruction; + +/** + * The purpose of this proxy is to shield systemml internals from direct access to MLContext + * which would try to load spark libraries and hence fail if these are not available. This + * indirection is much more efficient than catching NoClassDefFoundErrors for every access + * to MLContext (e.g., on each recompile). + * + */ +public class MLContextProxy +{ + + private static boolean _active = false; + + public static void setActive(boolean flag) { + _active = flag; + } + + public static boolean isActive() { + return _active; + } + + public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp) + { + return MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp); + } + + public static void setAppropriateVarsForRead(Expression source, String targetname) + throws LanguageException + { + MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname); + } + + public static MLContext getActiveMLContext() { + return MLContext.getActiveMLContext(); + } + + public static MLContext getActiveMLContextForAPI() { + if (MLContext.getActiveMLContext() != null) { + return MLContext.getActiveMLContext(); + } + throw new MLContextException("No MLContext object is currently active. Have you created one? " + + "Hint: in Scala, 'val ml = new MLContext(sc)'", true); + } + +}