[SYSTEMML-593] MLContext redesign Closes #199.
Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/457bbd3a Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/457bbd3a Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/457bbd3a Branch: refs/heads/master Commit: 457bbd3a4aca2c75163f4cbaed3faa2a9cb14d72 Parents: 873bae7 Author: Deron Eriksson <[email protected]> Authored: Thu Jul 28 17:11:40 2016 -0700 Committer: Deron Eriksson <[email protected]> Committed: Thu Jul 28 17:11:40 2016 -0700 ---------------------------------------------------------------------- pom.xml | 1 + .../java/org/apache/sysml/api/DMLScript.java | 2 +- .../java/org/apache/sysml/api/MLContext.java | 4 + .../org/apache/sysml/api/MLContextProxy.java | 55 +- .../sysml/api/mlcontext/BinaryBlockMatrix.java | 148 ++ .../apache/sysml/api/mlcontext/MLContext.java | 505 ++++++ .../api/mlcontext/MLContextConversionUtil.java | 720 ++++++++ .../sysml/api/mlcontext/MLContextException.java | 47 + .../sysml/api/mlcontext/MLContextUtil.java | 844 +++++++++ .../apache/sysml/api/mlcontext/MLResults.java | 1299 +++++++++++++ .../org/apache/sysml/api/mlcontext/Matrix.java | 141 ++ .../sysml/api/mlcontext/MatrixFormat.java | 39 + .../sysml/api/mlcontext/MatrixMetadata.java | 522 ++++++ .../org/apache/sysml/api/mlcontext/Script.java | 652 +++++++ .../sysml/api/mlcontext/ScriptExecutor.java | 624 +++++++ .../sysml/api/mlcontext/ScriptFactory.java | 422 +++++ .../apache/sysml/api/mlcontext/ScriptType.java | 65 + .../context/SparkExecutionContext.java | 47 +- .../instructions/spark/SPInstruction.java | 52 +- .../spark/functions/SparkListener.java | 38 +- .../spark/utils/RDDConverterUtilsExt.java | 4 +- .../integration/mlcontext/MLContextTest.java | 1713 ++++++++++++++++++ .../org/apache/sysml/api/mlcontext/1234.csv | 2 + .../org/apache/sysml/api/mlcontext/1234.csv.mtd | 13 + .../apache/sysml/api/mlcontext/hello-world.dml | 22 + 
.../sysml/api/mlcontext/hello-world.pydml | 22 + .../sysml/api/mlcontext/one-two-three-four.csv | 2 + .../api/mlcontext/one-two-three-four.csv.mtd | 5 + .../integration/mlcontext/ZPackageSuite.java | 37 + 29 files changed, 7991 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 9574679..f7ec5b7 100644 --- a/pom.xml +++ b/pom.xml @@ -394,6 +394,7 @@ <include>**/integration/functions/gdfo/*Suite.java</include> <include>**/integration/functions/sparse/*Suite.java</include> <include>**/integration/functions/**/*Test*.java</include> + <include>**/integration/mlcontext/*Suite.java</include> <include>**/integration/scalability/**/*Test.java</include> </includes> http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/DMLScript.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java b/src/main/java/org/apache/sysml/api/DMLScript.java index 814bcb8..3d76273 100644 --- a/src/main/java/org/apache/sysml/api/DMLScript.java +++ b/src/main/java/org/apache/sysml/api/DMLScript.java @@ -777,7 +777,7 @@ public class DMLScript * @throws DMLRuntimeException * */ - static void initHadoopExecution( DMLConfig config ) + public static void initHadoopExecution( DMLConfig config ) throws IOException, ParseException, DMLRuntimeException { //check security aspects http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/MLContext.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLContext.java b/src/main/java/org/apache/sysml/api/MLContext.java index 2600c35..a03c8b7 100644 --- 
a/src/main/java/org/apache/sysml/api/MLContext.java +++ b/src/main/java/org/apache/sysml/api/MLContext.java @@ -87,6 +87,10 @@ import org.apache.sysml.utils.Explain.ExplainCounts; import org.apache.sysml.utils.Statistics; /** + * The MLContext API has been redesigned and this API will be deprecated. + * Please migrate to {@link org.apache.sysml.api.mlcontext.MLContext}. + * <p> + * * MLContext is useful for passing RDDs as input/output to SystemML. This API avoids the need to read/write * from HDFS (which is another way to pass inputs to SystemML). * <p> http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/MLContextProxy.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/MLContextProxy.java b/src/main/java/org/apache/sysml/api/MLContextProxy.java index ee16690..f8f31d6 100644 --- a/src/main/java/org/apache/sysml/api/MLContextProxy.java +++ b/src/main/java/org/apache/sysml/api/MLContextProxy.java @@ -61,8 +61,10 @@ public class MLContextProxy */ public static ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> tmp) { - if(MLContext.getActiveMLContext() != null) { - return MLContext.getActiveMLContext().performCleanupAfterRecompilation(tmp); + if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) { + return org.apache.sysml.api.MLContext.getActiveMLContext().performCleanupAfterRecompilation(tmp); + } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) { + return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().performCleanupAfterRecompilation(tmp); } return tmp; } @@ -76,28 +78,55 @@ public class MLContextProxy public static void setAppropriateVarsForRead(Expression source, String targetname) throws LanguageException { - MLContext mlContext = MLContext.getActiveMLContext(); - if(mlContext != null) { - mlContext.setAppropriateVarsForRead(source, 
targetname); + if(org.apache.sysml.api.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.MLContext.getActiveMLContext().setAppropriateVarsForRead(source, targetname); + } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext().getInternalProxy().setAppropriateVarsForRead(source, targetname); } } - public static MLContext getActiveMLContext() { - return MLContext.getActiveMLContext(); + public static Object getActiveMLContext() { + if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) { + return org.apache.sysml.api.MLContext.getActiveMLContext(); + } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) { + return org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext(); + } else { + return null; + } + } public static void setInstructionForMonitoring(Instruction inst) { Location loc = inst.getLocation(); - MLContext mlContext = MLContext.getActiveMLContext(); - if(loc != null && mlContext != null && mlContext.getMonitoringUtil() != null) { - mlContext.getMonitoringUtil().setInstructionLocation(loc, inst); + if (loc == null) { + return; + } + + if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.MLContext mlContext = org.apache.sysml.api.MLContext.getActiveMLContext(); + if(mlContext.getMonitoringUtil() != null) { + mlContext.getMonitoringUtil().setInstructionLocation(loc, inst); + } + } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.mlcontext.MLContext mlContext = org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext(); + if(mlContext.getSparkMonitoringUtil() != null) { + mlContext.getSparkMonitoringUtil().setInstructionLocation(loc, inst); + } } } public static void addRDDForInstructionForMonitoring(SPInstruction inst, Integer rddID) { - MLContext mlContext = MLContext.getActiveMLContext(); - if(mlContext != null && 
mlContext.getMonitoringUtil() != null) { - mlContext.getMonitoringUtil().addRDDForInstruction(inst, rddID); + + if (org.apache.sysml.api.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.MLContext mlContext = org.apache.sysml.api.MLContext.getActiveMLContext(); + if(mlContext.getMonitoringUtil() != null) { + mlContext.getMonitoringUtil().addRDDForInstruction(inst, rddID); + } + } else if (org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext() != null) { + org.apache.sysml.api.mlcontext.MLContext mlContext = org.apache.sysml.api.mlcontext.MLContext.getActiveMLContext(); + if(mlContext.getSparkMonitoringUtil() != null) { + mlContext.getSparkMonitoringUtil().addRDDForInstruction(inst, rddID); + } } } http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java new file mode 100644 index 0000000..8c9f923 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/BinaryBlockMatrix.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.sql.DataFrame; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; + +/** + * BinaryBlockMatrix stores data as a SystemML binary-block representation. + * + */ +public class BinaryBlockMatrix { + + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks; + MatrixMetadata matrixMetadata; + + /** + * Convert a Spark DataFrame to a SystemML binary-block representation. + * + * @param dataFrame + * the Spark DataFrame + * @param matrixMetadata + * matrix metadata, such as number of rows and columns + */ + public BinaryBlockMatrix(DataFrame dataFrame, MatrixMetadata matrixMetadata) { + this.matrixMetadata = matrixMetadata; + binaryBlocks = MLContextConversionUtil.dataFrameToBinaryBlocks(dataFrame, matrixMetadata); + } + + /** + * Convert a Spark DataFrame to a SystemML binary-block representation, + * specifying the number of rows and columns. + * + * @param dataFrame + * the Spark DataFrame + * @param numRows + * the number of rows + * @param numCols + * the number of columns + */ + public BinaryBlockMatrix(DataFrame dataFrame, long numRows, long numCols) { + this(dataFrame, new MatrixMetadata(numRows, numCols, MLContextUtil.defaultBlockSize(), + MLContextUtil.defaultBlockSize())); + } + + /** + * Convert a Spark DataFrame to a SystemML binary-block representation. + * + * @param dataFrame + * the Spark DataFrame + */ + public BinaryBlockMatrix(DataFrame dataFrame) { + this(dataFrame, new MatrixMetadata()); + } + + /** + * Create a BinaryBlockMatrix, specifying the SystemML binary-block matrix + * and its metadata. 
+ * + * @param binaryBlocks + * the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix + * @param matrixCharacteristics + * the matrix metadata as {@code MatrixCharacteristics} + */ + public BinaryBlockMatrix(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks, + MatrixCharacteristics matrixCharacteristics) { + this.binaryBlocks = binaryBlocks; + this.matrixMetadata = new MatrixMetadata(matrixCharacteristics); + } + + /** + * Obtain a SystemML binary-block matrix as a + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} + * + * @return the SystemML binary-block matrix + */ + public JavaPairRDD<MatrixIndexes, MatrixBlock> getBinaryBlocks() { + return binaryBlocks; + } + + /** + * Obtain the SystemML binary-block matrix characteristics + * + * @return the matrix metadata as {@code MatrixCharacteristics} + */ + public MatrixCharacteristics getMatrixCharacteristics() { + return matrixMetadata.asMatrixCharacteristics(); + } + + /** + * Obtain the SystemML binary-block matrix metadata + * + * @return the matrix metadata as {@code MatrixMetadata} + */ + public MatrixMetadata getMatrixMetadata() { + return matrixMetadata; + } + + /** + * Set the SystemML binary-block matrix metadata + * + * @param matrixMetadata + * the matrix metadata + */ + public void setMatrixMetadata(MatrixMetadata matrixMetadata) { + this.matrixMetadata = matrixMetadata; + } + + /** + * Set the SystemML binary-block matrix as a + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} + * + * @param binaryBlocks + * the SystemML binary-block matrix + */ + public void setBinaryBlocks(JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks) { + this.binaryBlocks = binaryBlocks; + } + + @Override + public String toString() { + if (matrixMetadata != null) { + return matrixMetadata.toString(); + } else { + return super.toString(); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java 
---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java new file mode 100644 index 0000000..05deec2 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContext.java @@ -0,0 +1,505 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysml.api.mlcontext; + +import java.util.ArrayList; +import java.util.Date; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.sysml.api.DMLScript; +import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; +import org.apache.sysml.api.MLContextProxy; +import org.apache.sysml.api.monitoring.SparkMonitoringUtil; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.conf.DMLConfig; +import org.apache.sysml.parser.DataExpression; +import org.apache.sysml.parser.Expression; +import org.apache.sysml.parser.IntIdentifier; +import org.apache.sysml.parser.StringIdentifier; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.caching.CacheableData; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.Instruction; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.instructions.cp.VariableCPInstruction; +import org.apache.sysml.runtime.instructions.spark.functions.SparkListener; +import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; +import org.apache.sysml.runtime.matrix.data.OutputInfo; + +/** + * The MLContext API offers programmatic access to SystemML on Spark from + * languages such as Scala, Java, and Python. + * + */ +public class MLContext { + + /** + * Minimum Spark version supported by SystemML. + */ + public static final String SYSTEMML_MINIMUM_SPARK_VERSION = "1.4.0"; + + /** + * SparkContext object. + */ + private SparkContext sc = null; + + /** + * SparkMonitoringUtil monitors SystemML performance on Spark. 
+ */ + private SparkMonitoringUtil sparkMonitoringUtil = null; + + /** + * Reference to the currently executing script. + */ + private Script executingScript = null; + + /** + * The currently active MLContext. + */ + private static MLContext activeMLContext = null; + + /** + * Contains cleanup methods used by MLContextProxy. + */ + private InternalProxy internalProxy = new InternalProxy(); + + /** + * Whether or not an explanation of the DML/PYDML program should be output + * to standard output. + */ + private boolean explain = false; + + /** + * Whether or not statistics of the DML/PYDML program execution should be + * output to standard output. + */ + private boolean statistics = false; + + private List<String> scriptHistoryStrings = new ArrayList<String>(); + private Map<String, Script> scripts = new LinkedHashMap<String, Script>(); + + /** + * Retrieve the currently active MLContext. This is used internally by + * SystemML via MLContextProxy. + * + * @return the active MLContext + */ + public static MLContext getActiveMLContext() { + return activeMLContext; + } + + /** + * Create an MLContext based on a SparkContext for interaction with SystemML + * on Spark. + * + * @param sparkContext + * SparkContext + */ + public MLContext(SparkContext sparkContext) { + this(sparkContext, false); + } + + /** + * Create an MLContext based on a JavaSparkContext for interaction with + * SystemML on Spark. + * + * @param javaSparkContext + * JavaSparkContext + */ + public MLContext(JavaSparkContext javaSparkContext) { + this(javaSparkContext.sc(), false); + } + + /** + * Create an MLContext based on a SparkContext for interaction with SystemML + * on Spark, optionally monitor performance. + * + * @param sc + * SparkContext object. 
+ * @param monitorPerformance + * {@code true} if performance should be monitored, {@code false} + * otherwise + */ + public MLContext(SparkContext sc, boolean monitorPerformance) { + initMLContext(sc, monitorPerformance); + } + + /** + * Initialize MLContext. Verify Spark version supported, set default + * execution mode, set MLContextProxy, set default config, set compiler + * config, and configure monitoring if needed. + * + * @param sc + * SparkContext object. + * @param monitorPerformance + * {@code true} if performance should be monitored, {@code false} + * otherwise + */ + private void initMLContext(SparkContext sc, boolean monitorPerformance) { + + if (activeMLContext == null) { + System.out.println(MLContextUtil.welcomeMessage()); + } + + this.sc = sc; + MLContextUtil.verifySparkVersionSupported(sc); + // by default, run in hybrid Spark mode for optimal performance + DMLScript.rtplatform = RUNTIME_PLATFORM.HYBRID_SPARK; + + activeMLContext = this; + MLContextProxy.setActive(true); + + MLContextUtil.setDefaultConfig(); + MLContextUtil.setCompilerConfig(); + + if (monitorPerformance) { + SparkListener sparkListener = new SparkListener(sc); + sparkMonitoringUtil = new SparkMonitoringUtil(sparkListener); + sc.addSparkListener(sparkListener); + } + } + + /** + * Clean up the variables from the buffer pool, including evicted files, + * because the buffer pool holds references. + */ + public void clearCache() { + CacheableData.cleanupCacheDir(); + } + + /** + * Reset configuration settings to default settings. + */ + public void resetConfig() { + MLContextUtil.setDefaultConfig(); + } + + /** + * Set configuration property, such as + * {@code setConfigProperty("localtmpdir", "/tmp/systemml")}. 
+ * + * @param propertyName + * property name + * @param propertyValue + * property value + */ + public void setConfigProperty(String propertyName, String propertyValue) { + DMLConfig config = ConfigurationManager.getDMLConfig(); + try { + config.setTextValue(propertyName, propertyValue); + } catch (DMLRuntimeException e) { + throw new MLContextException(e); + } + } + + /** + * Execute a DML or PYDML Script. + * + * @param script + * The DML or PYDML Script object to execute. + */ + public MLResults execute(Script script) { + ScriptExecutor scriptExecutor = new ScriptExecutor(sparkMonitoringUtil); + scriptExecutor.setExplain(explain); + scriptExecutor.setStatistics(statistics); + return execute(script, scriptExecutor); + } + + /** + * Execute a DML or PYDML Script object using a ScriptExecutor. The + * ScriptExecutor class can be extended to allow the modification of the + * default execution pathway. + * + * @param script + * the DML or PYDML Script object + * @param scriptExecutor + * the ScriptExecutor that defines the script execution pathway + */ + public MLResults execute(Script script, ScriptExecutor scriptExecutor) { + try { + executingScript = script; + + Long time = new Long((new Date()).getTime()); + if ((script.getName() == null) || (script.getName().equals(""))) { + script.setName(time.toString()); + } + + MLResults results = scriptExecutor.execute(script); + + String history = MLContextUtil.createHistoryForScript(script, time); + scriptHistoryStrings.add(history); + scripts.put(script.getName(), script); + + return results; + } catch (RuntimeException e) { + throw new MLContextException("Exception when executing script", e); + } + } + + /** + * Set SystemML configuration based on a configuration file. + * + * @param configFilePath + * path to the configuration file + */ + public void setConfig(String configFilePath) { + MLContextUtil.setConfig(configFilePath); + } + + /** + * Obtain the SparkMonitoringUtil if it is available. 
+ * + * @return the SparkMonitoringUtil if it is available. + */ + public SparkMonitoringUtil getSparkMonitoringUtil() { + return sparkMonitoringUtil; + } + + /** + * Obtain the SparkContext associated with this MLContext. + * + * @return the SparkContext associated with this MLContext. + */ + public SparkContext getSparkContext() { + return sc; + } + + /** + * Whether or not an explanation of the DML/PYDML program should be output + * to standard output. + * + * @return {@code true} if explanation should be output, {@code false} + * otherwise + */ + public boolean isExplain() { + return explain; + } + + /** + * Whether or not an explanation of the DML/PYDML program should be output + * to standard output. + * + * @param explain + * {@code true} if explanation should be output, {@code false} + * otherwise + */ + public void setExplain(boolean explain) { + this.explain = explain; + } + + /** + * Used internally by MLContextProxy. + * + */ + public class InternalProxy { + + public void setAppropriateVarsForRead(Expression source, String target) { + boolean isTargetRegistered = isRegisteredAsInput(target); + boolean isReadExpression = (source instanceof DataExpression && ((DataExpression) source).isRead()); + if (isTargetRegistered && isReadExpression) { + DataExpression exp = (DataExpression) source; + // Do not check metadata file for registered reads + exp.setCheckMetadata(false); + + MatrixObject mo = getMatrixObject(target); + if (mo != null) { + int blp = source.getBeginLine(); + int bcp = source.getBeginColumn(); + int elp = source.getEndLine(); + int ecp = source.getEndColumn(); + exp.addVarParam(DataExpression.READROWPARAM, + new IntIdentifier(mo.getNumRows(), source.getFilename(), blp, bcp, elp, ecp)); + exp.addVarParam(DataExpression.READCOLPARAM, + new IntIdentifier(mo.getNumColumns(), source.getFilename(), blp, bcp, elp, ecp)); + exp.addVarParam(DataExpression.READNUMNONZEROPARAM, + new IntIdentifier(mo.getNnz(), source.getFilename(), blp, bcp, elp, 
ecp)); + exp.addVarParam(DataExpression.DATATYPEPARAM, new StringIdentifier("matrix", source.getFilename(), + blp, bcp, elp, ecp)); + exp.addVarParam(DataExpression.VALUETYPEPARAM, new StringIdentifier("double", source.getFilename(), + blp, bcp, elp, ecp)); + + if (mo.getMetaData() instanceof MatrixFormatMetaData) { + MatrixFormatMetaData metaData = (MatrixFormatMetaData) mo.getMetaData(); + if (metaData.getOutputInfo() == OutputInfo.CSVOutputInfo) { + exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier( + DataExpression.FORMAT_TYPE_VALUE_CSV, source.getFilename(), blp, bcp, elp, ecp)); + } else if (metaData.getOutputInfo() == OutputInfo.TextCellOutputInfo) { + exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier( + DataExpression.FORMAT_TYPE_VALUE_TEXT, source.getFilename(), blp, bcp, elp, ecp)); + } else if (metaData.getOutputInfo() == OutputInfo.BinaryBlockOutputInfo) { + exp.addVarParam( + DataExpression.ROWBLOCKCOUNTPARAM, + new IntIdentifier(mo.getNumRowsPerBlock(), source.getFilename(), blp, bcp, elp, ecp)); + exp.addVarParam(DataExpression.COLUMNBLOCKCOUNTPARAM, + new IntIdentifier(mo.getNumColumnsPerBlock(), source.getFilename(), blp, bcp, elp, + ecp)); + exp.addVarParam(DataExpression.FORMAT_TYPE, new StringIdentifier( + DataExpression.FORMAT_TYPE_VALUE_BINARY, source.getFilename(), blp, bcp, elp, ecp)); + } else { + throw new MLContextException("Unsupported format through MLContext"); + } + } + } + + } + } + + private boolean isRegisteredAsInput(String parameterName) { + if (executingScript != null) { + Set<String> inputVariableNames = executingScript.getInputVariables(); + if (inputVariableNames != null) { + return inputVariableNames.contains(parameterName); + } + } + return false; + } + + private MatrixObject getMatrixObject(String parameterName) { + if (executingScript != null) { + LocalVariableMap symbolTable = executingScript.getSymbolTable(); + if (symbolTable != null) { + Data data = symbolTable.get(parameterName); + if 
(data instanceof MatrixObject) { + return (MatrixObject) data; + } else { + if (data instanceof ScalarObject) { + return null; + } + } + } + } + throw new MLContextException("getMatrixObject not set for parameter: " + parameterName); + } + + public ArrayList<Instruction> performCleanupAfterRecompilation(ArrayList<Instruction> instructions) { + if (executingScript == null) { + return instructions; + } + Set<String> outputVariableNames = executingScript.getOutputVariables(); + if (outputVariableNames == null) { + return instructions; + } + + for (int i = 0; i < instructions.size(); i++) { + Instruction inst = instructions.get(i); + if (inst instanceof VariableCPInstruction && ((VariableCPInstruction) inst).isRemoveVariable()) { + VariableCPInstruction varInst = (VariableCPInstruction) inst; + for (String outputVariableName : outputVariableNames) + if (varInst.isRemoveVariable(outputVariableName)) { + instructions.remove(i); + i--; + break; + } + } + } + return instructions; + } + } + + /** + * Used internally by MLContextProxy. + * + */ + public InternalProxy getInternalProxy() { + return internalProxy; + } + + /** + * Whether or not statistics of the DML/PYDML program execution should be + * output to standard output. + * + * @return {@code true} if statistics should be output, {@code false} + * otherwise + */ + public boolean isStatistics() { + return statistics; + } + + /** + * Whether or not statistics of the DML/PYDML program execution should be + * output to standard output. + * + * @param statistics + * {@code true} if statistics should be output, {@code false} + * otherwise + */ + public void setStatistics(boolean statistics) { + DMLScript.STATISTICS = statistics; + this.statistics = statistics; + } + + /** + * Obtain a map of the scripts that have executed. + * + * @return a map of the scripts that have executed + */ + public Map<String, Script> getScripts() { + return scripts; + } + + /** + * Obtain a script that has executed by name. 
+ * + * @param name + * the name of the script + * @return the script corresponding to the name + */ + public Script getScriptByName(String name) { + Script script = scripts.get(name); + if (script == null) { + throw new MLContextException("Script with name '" + name + "' not found."); + } + return script; + } + + /** + * Display the history of scripts that have executed. + * + * @return the history of scripts that have executed + */ + public String history() { + return MLContextUtil.displayScriptHistory(scriptHistoryStrings); + } + + /** + * Clear all the scripts, removing them from the history, and clear the + * cache. + */ + public void clear() { + Set<String> scriptNames = scripts.keySet(); + for (String scriptName : scriptNames) { + Script script = scripts.get(scriptName); + script.clearAll(); + } + + scripts.clear(); + scriptHistoryStrings.clear(); + + clearCache(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java new file mode 100644 index 0000000..33226d2 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextConversionUtil.java @@ -0,0 +1,720 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.Text; +import org.apache.spark.Accumulator; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +import org.apache.sysml.api.MLContextProxy; +import org.apache.sysml.parser.Expression.ValueType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.CacheException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysml.runtime.instructions.spark.data.RDDObject; +import org.apache.sysml.runtime.instructions.spark.functions.ConvertStringToLongTextPair; +import org.apache.sysml.runtime.instructions.spark.functions.CopyBlockPairFunction; +import org.apache.sysml.runtime.instructions.spark.functions.CopyTextInputFunction; +import org.apache.sysml.runtime.instructions.spark.utils.RDDAggregateUtils; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt; +import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameAnalysisFunction; +import 
org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt.DataFrameToBinaryBlockFunction; +import org.apache.sysml.runtime.matrix.MatrixCharacteristics; +import org.apache.sysml.runtime.matrix.MatrixFormatMetaData; +import org.apache.sysml.runtime.matrix.data.IJV; +import org.apache.sysml.runtime.matrix.data.InputInfo; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import org.apache.sysml.runtime.matrix.data.MatrixIndexes; +import org.apache.sysml.runtime.matrix.data.OutputInfo; +import org.apache.sysml.runtime.util.DataConverter; +import org.apache.sysml.runtime.util.UtilFunctions; + +import scala.collection.JavaConversions; +import scala.reflect.ClassTag; + +/** + * Utility class containing methods to perform data conversions. + * + */ +public class MLContextConversionUtil { + + /** + * Convert a two-dimensional double array to a {@code MatrixObject}. + * + * @param variableName + * name of the variable associated with the matrix + * @param doubleMatrix + * matrix of double values + * @return the two-dimensional double matrix converted to a + * {@code MatrixObject} + */ + public static MatrixObject doubleMatrixToMatrixObject(String variableName, double[][] doubleMatrix) { + return doubleMatrixToMatrixObject(variableName, doubleMatrix, null); + } + + /** + * Convert a two-dimensional double array to a {@code MatrixObject}. 
+ * + * @param variableName + * name of the variable associated with the matrix + * @param doubleMatrix + * matrix of double values + * @param matrixMetadata + * the matrix metadata + * @return the two-dimensional double matrix converted to a + * {@code MatrixObject} + */ + public static MatrixObject doubleMatrixToMatrixObject(String variableName, double[][] doubleMatrix, + MatrixMetadata matrixMetadata) { + try { + MatrixBlock matrixBlock = DataConverter.convertToMatrixBlock(doubleMatrix); + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + } else { + matrixCharacteristics = new MatrixCharacteristics(matrixBlock.getNumRows(), + matrixBlock.getNumColumns(), MLContextUtil.defaultBlockSize(), MLContextUtil.defaultBlockSize()); + } + + MatrixFormatMetaData meta = new MatrixFormatMetaData(matrixCharacteristics, + OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo); + MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/" + + variableName, meta); + matrixObject.acquireModify(matrixBlock); + matrixObject.release(); + return matrixObject; + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception converting double[][] array to MatrixObject", e); + } + } + + /** + * Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a + * {@code MatrixObject}. 
+ * + * @param variableName + * name of the variable associated with the matrix + * @param binaryBlocks + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} representation + * of a binary-block matrix + * @return the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix + * converted to a {@code MatrixObject} + */ + public static MatrixObject binaryBlocksToMatrixObject(String variableName, + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks) { + return binaryBlocksToMatrixObject(variableName, binaryBlocks, null); + } + + /** + * Convert a {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} to a + * {@code MatrixObject}. + * + * @param variableName + * name of the variable associated with the matrix + * @param binaryBlocks + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} representation + * of a binary-block matrix + * @param matrixMetadata + * the matrix metadata + * @return the {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} matrix + * converted to a {@code MatrixObject} + */ + public static MatrixObject binaryBlocksToMatrixObject(String variableName, + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks, MatrixMetadata matrixMetadata) { + + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + } else { + matrixCharacteristics = new MatrixCharacteristics(); + } + + JavaPairRDD<MatrixIndexes, MatrixBlock> javaPairRdd = binaryBlocks.mapToPair(new CopyBlockPairFunction()); + + MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, MLContextUtil.scratchSpace() + "/" + "temp_" + + System.nanoTime(), new MatrixFormatMetaData(matrixCharacteristics, OutputInfo.BinaryBlockOutputInfo, + InputInfo.BinaryBlockInputInfo)); + matrixObject.setRDDHandle(new RDDObject(javaPairRdd, variableName)); + return matrixObject; + } + + /** + * Convert a {@code DataFrame} to a {@code MatrixObject}. 
+ * + * @param variableName + * name of the variable associated with the matrix + * @param dataFrame + * the Spark {@code DataFrame} + * @return the {@code DataFrame} matrix converted to a converted to a + * {@code MatrixObject} + */ + public static MatrixObject dataFrameToMatrixObject(String variableName, DataFrame dataFrame) { + return dataFrameToMatrixObject(variableName, dataFrame, null); + } + + /** + * Convert a {@code DataFrame} to a {@code MatrixObject}. + * + * @param variableName + * name of the variable associated with the matrix + * @param dataFrame + * the Spark {@code DataFrame} + * @param matrixMetadata + * the matrix metadata + * @return the {@code DataFrame} matrix converted to a converted to a + * {@code MatrixObject} + */ + public static MatrixObject dataFrameToMatrixObject(String variableName, DataFrame dataFrame, + MatrixMetadata matrixMetadata) { + if (matrixMetadata == null) { + matrixMetadata = new MatrixMetadata(); + } + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = MLContextConversionUtil.dataFrameToBinaryBlocks( + dataFrame, matrixMetadata); + MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(variableName, binaryBlock, + matrixMetadata); + return matrixObject; + } + + /** + * Convert a {@code DataFrame} to a + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} binary-block matrix. + * + * @param dataFrame + * the Spark {@code DataFrame} + * @return the {@code DataFrame} matrix converted to a + * {@code JavaPairRDD<MatrixIndexes, + * MatrixBlock>} binary-block matrix + */ + public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame) { + return dataFrameToBinaryBlocks(dataFrame, null); + } + + /** + * Convert a {@code DataFrame} to a + * {@code JavaPairRDD<MatrixIndexes, MatrixBlock>} binary-block matrix. 
+ * + * @param dataFrame + * the Spark {@code DataFrame} + * @param matrixMetadata + * the matrix metadata + * @return the {@code DataFrame} matrix converted to a + * {@code JavaPairRDD<MatrixIndexes, + * MatrixBlock>} binary-block matrix + */ + public static JavaPairRDD<MatrixIndexes, MatrixBlock> dataFrameToBinaryBlocks(DataFrame dataFrame, + MatrixMetadata matrixMetadata) { + + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + if (matrixCharacteristics == null) { + matrixCharacteristics = new MatrixCharacteristics(); + } + } else { + matrixCharacteristics = new MatrixCharacteristics(); + } + determineDataFrameDimensionsIfNeeded(dataFrame, matrixCharacteristics); + if (matrixMetadata != null) { + // so external reference can be updated with the metadata + matrixMetadata.setMatrixCharacteristics(matrixCharacteristics); + } + + JavaRDD<Row> javaRDD = dataFrame.javaRDD(); + JavaPairRDD<Row, Long> prepinput = javaRDD.zipWithIndex(); + JavaPairRDD<MatrixIndexes, MatrixBlock> out = prepinput.mapPartitionsToPair(new DataFrameToBinaryBlockFunction( + matrixCharacteristics, false)); + out = RDDAggregateUtils.mergeByKey(out); + return out; + } + + /** + * If the {@code DataFrame} dimensions aren't present in the + * {@code MatrixCharacteristics} metadata, determine the dimensions and + * place them in the {@code MatrixCharacteristics} metadata. 
+ * + * @param dataFrame + * the Spark {@code DataFrame} + * @param matrixCharacteristics + * the matrix metadata + */ + public static void determineDataFrameDimensionsIfNeeded(DataFrame dataFrame, + MatrixCharacteristics matrixCharacteristics) { + if (!matrixCharacteristics.dimsKnown(true)) { + // only available to the new MLContext API, not the old API + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sparkContext = activeMLContext.getSparkContext(); + @SuppressWarnings("resource") + JavaSparkContext javaSparkContext = new JavaSparkContext(sparkContext); + + Accumulator<Double> aNnz = javaSparkContext.accumulator(0L); + JavaRDD<Row> javaRDD = dataFrame.javaRDD().map(new DataFrameAnalysisFunction(aNnz, false)); + long numRows = javaRDD.count(); + long numColumns = dataFrame.columns().length; + long numNonZeros = UtilFunctions.toLong(aNnz.value()); + matrixCharacteristics.set(numRows, numColumns, matrixCharacteristics.getRowsPerBlock(), + matrixCharacteristics.getColsPerBlock(), numNonZeros); + } + } + + /** + * Convert a {@code JavaRDD<String>} in CSV format to a {@code MatrixObject} + * + * @param variableName + * name of the variable associated with the matrix + * @param javaRDD + * the Java RDD of strings + * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject javaRDDStringCSVToMatrixObject(String variableName, JavaRDD<String> javaRDD) { + return javaRDDStringCSVToMatrixObject(variableName, javaRDD, null); + } + + /** + * Convert a {@code JavaRDD<String>} in CSV format to a {@code MatrixObject} + * + * @param variableName + * name of the variable associated with the matrix + * @param javaRDD + * the Java RDD of strings + * @param matrixMetadata + * matrix metadata + * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject javaRDDStringCSVToMatrixObject(String variableName, JavaRDD<String> javaRDD, + MatrixMetadata 
matrixMetadata) { + JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair()); + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + } else { + matrixCharacteristics = new MatrixCharacteristics(); + } + MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, null, new MatrixFormatMetaData( + matrixCharacteristics, OutputInfo.CSVOutputInfo, InputInfo.CSVInputInfo)); + JavaPairRDD<LongWritable, Text> javaPairRDD2 = javaPairRDD.mapToPair(new CopyTextInputFunction()); + matrixObject.setRDDHandle(new RDDObject(javaPairRDD2, variableName)); + return matrixObject; + } + + /** + * Convert a {@code JavaRDD<String>} in IJV format to a {@code MatrixObject} + * . Note that metadata is required for IJV format. + * + * @param variableName + * name of the variable associated with the matrix + * @param javaRDD + * the Java RDD of strings + * @param matrixMetadata + * matrix metadata + * @return the {@code JavaRDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject javaRDDStringIJVToMatrixObject(String variableName, JavaRDD<String> javaRDD, + MatrixMetadata matrixMetadata) { + JavaPairRDD<LongWritable, Text> javaPairRDD = javaRDD.mapToPair(new ConvertStringToLongTextPair()); + MatrixCharacteristics matrixCharacteristics; + if (matrixMetadata != null) { + matrixCharacteristics = matrixMetadata.asMatrixCharacteristics(); + } else { + matrixCharacteristics = new MatrixCharacteristics(); + } + MatrixObject matrixObject = new MatrixObject(ValueType.DOUBLE, null, new MatrixFormatMetaData( + matrixCharacteristics, OutputInfo.TextCellOutputInfo, InputInfo.TextCellInputInfo)); + JavaPairRDD<LongWritable, Text> javaPairRDD2 = javaPairRDD.mapToPair(new CopyTextInputFunction()); + matrixObject.setRDDHandle(new RDDObject(javaPairRDD2, variableName)); + return matrixObject; + } + + /** + * Convert a {@code RDD<String>} in CSV 
format to a {@code MatrixObject} + * + * @param variableName + * name of the variable associated with the matrix + * @param rdd + * the RDD of strings + * @return the {@code RDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject rddStringCSVToMatrixObject(String variableName, RDD<String> rdd) { + return rddStringCSVToMatrixObject(variableName, rdd, null); + } + + /** + * Convert a {@code RDD<String>} in CSV format to a {@code MatrixObject} + * + * @param variableName + * name of the variable associated with the matrix + * @param rdd + * the RDD of strings + * @param matrixMetadata + * matrix metadata + * @return the {@code RDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject rddStringCSVToMatrixObject(String variableName, RDD<String> rdd, + MatrixMetadata matrixMetadata) { + ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); + JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag); + return javaRDDStringCSVToMatrixObject(variableName, javaRDD, matrixMetadata); + } + + /** + * Convert a {@code RDD<String>} in IJV format to a {@code MatrixObject}. + * Note that metadata is required for IJV format. + * + * @param variableName + * name of the variable associated with the matrix + * @param rdd + * the RDD of strings + * @param matrixMetadata + * matrix metadata + * @return the {@code RDD<String>} converted to a {@code MatrixObject} + */ + public static MatrixObject rddStringIJVToMatrixObject(String variableName, RDD<String> rdd, + MatrixMetadata matrixMetadata) { + ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); + JavaRDD<String> javaRDD = JavaRDD.fromRDD(rdd, tag); + return javaRDDStringIJVToMatrixObject(variableName, javaRDD, matrixMetadata); + } + + /** + * Convert an {@code BinaryBlockMatrix} to a {@code JavaRDD<String>} in IVJ + * format. 
+ * + * @param binaryBlockMatrix + * the {@code BinaryBlockMatrix} + * @return the {@code BinaryBlockMatrix} converted to a + * {@code JavaRDD<String>} + */ + public static JavaRDD<String> binaryBlockMatrixToJavaRDDStringIJV(BinaryBlockMatrix binaryBlockMatrix) { + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = binaryBlockMatrix.getBinaryBlocks(); + MatrixCharacteristics matrixCharacteristics = binaryBlockMatrix.getMatrixCharacteristics(); + try { + JavaRDD<String> javaRDDString = RDDConverterUtilsExt.binaryBlockToStringRDD(binaryBlock, + matrixCharacteristics, "text"); + return javaRDDString; + } catch (DMLRuntimeException e) { + throw new MLContextException("Exception converting BinaryBlockMatrix to JavaRDD<String> (ijv)", e); + } + } + + /** + * Convert an {@code BinaryBlockMatrix} to a {@code RDD<String>} in IVJ + * format. + * + * @param binaryBlockMatrix + * the {@code BinaryBlockMatrix} + * @return the {@code BinaryBlockMatrix} converted to a {@code RDD<String>} + */ + public static RDD<String> binaryBlockMatrixToRDDStringIJV(BinaryBlockMatrix binaryBlockMatrix) { + JavaRDD<String> javaRDD = binaryBlockMatrixToJavaRDDStringIJV(binaryBlockMatrix); + RDD<String> rdd = JavaRDD.toRDD(javaRDD); + return rdd; + } + + /** + * Convert a {@code MatrixObject} to a {@code JavaRDD<String>} in CSV + * format. 
+ * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code JavaRDD<String>} + */ + public static JavaRDD<String> matrixObjectToJavaRDDStringCSV(MatrixObject matrixObject) { + List<String> list = matrixObjectToListStringCSV(matrixObject); + + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sc = activeMLContext.getSparkContext(); + @SuppressWarnings("resource") + JavaSparkContext jsc = new JavaSparkContext(sc); + JavaRDD<String> javaRDDStringCSV = jsc.parallelize(list); + return javaRDDStringCSV; + } + + /** + * Convert a {@code MatrixObject} to a {@code JavaRDD<String>} in IJV + * format. + * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code JavaRDD<String>} + */ + public static JavaRDD<String> matrixObjectToJavaRDDStringIJV(MatrixObject matrixObject) { + List<String> list = matrixObjectToListStringIJV(matrixObject); + + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sc = activeMLContext.getSparkContext(); + @SuppressWarnings("resource") + JavaSparkContext jsc = new JavaSparkContext(sc); + JavaRDD<String> javaRDDStringCSV = jsc.parallelize(list); + return javaRDDStringCSV; + } + + /** + * Convert a {@code MatrixObject} to a {@code RDD<String>} in IJV format. + * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code RDD<String>} + */ + public static RDD<String> matrixObjectToRDDStringIJV(MatrixObject matrixObject) { + + // NOTE: The following works when called from Java but does not + // currently work when called from Spark Shell (when you call + // collect() on the RDD<String>). 
+ // + // JavaRDD<String> javaRDD = jsc.parallelize(list); + // RDD<String> rdd = JavaRDD.toRDD(javaRDD); + // + // Therefore, we call parallelize() on the SparkContext rather than + // the JavaSparkContext to produce the RDD<String> for Scala. + + List<String> list = matrixObjectToListStringIJV(matrixObject); + + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sc = activeMLContext.getSparkContext(); + ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); + RDD<String> rddString = sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return rddString; + } + + /** + * Convert a {@code MatrixObject} to a {@code RDD<String>} in CSV format. + * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code RDD<String>} + */ + public static RDD<String> matrixObjectToRDDStringCSV(MatrixObject matrixObject) { + + // NOTE: The following works when called from Java but does not + // currently work when called from Spark Shell (when you call + // collect() on the RDD<String>). + // + // JavaRDD<String> javaRDD = jsc.parallelize(list); + // RDD<String> rdd = JavaRDD.toRDD(javaRDD); + // + // Therefore, we call parallelize() on the SparkContext rather than + // the JavaSparkContext to produce the RDD<String> for Scala. + + List<String> list = matrixObjectToListStringCSV(matrixObject); + + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sc = activeMLContext.getSparkContext(); + ClassTag<String> tag = scala.reflect.ClassTag$.MODULE$.apply(String.class); + RDD<String> rddString = sc.parallelize(JavaConversions.asScalaBuffer(list), sc.defaultParallelism(), tag); + return rddString; + } + + /** + * Convert a {@code MatrixObject} to a {@code List<String>} in CSV format. 
+ * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code List<String>} + */ + public static List<String> matrixObjectToListStringCSV(MatrixObject matrixObject) { + try { + MatrixBlock mb = matrixObject.acquireRead(); + + int rows = mb.getNumRows(); + int cols = mb.getNumColumns(); + List<String> list = new ArrayList<String>(); + + if (mb.getNonZeros() > 0) { + if (mb.isInSparseFormat()) { + Iterator<IJV> iter = mb.getSparseBlockIterator(); + int prevCellRow = -1; + StringBuilder sb = null; + while (iter.hasNext()) { + IJV cell = iter.next(); + int i = cell.getI(); + double v = cell.getV(); + if (i > prevCellRow) { + if (sb == null) { + sb = new StringBuilder(); + } else { + list.add(sb.toString()); + sb = new StringBuilder(); + } + sb.append(v); + prevCellRow = i; + } else if (i == prevCellRow) { + sb.append(","); + sb.append(v); + } + } + if (sb != null) { + list.add(sb.toString()); + } + } else { + for (int i = 0; i < rows; i++) { + StringBuilder sb = new StringBuilder(); + for (int j = 0; j < cols; j++) { + if (j > 0) { + sb.append(","); + } + sb.append(mb.getValueDenseUnsafe(i, j)); + } + list.add(sb.toString()); + } + } + } + + matrixObject.release(); + return list; + } catch (CacheException e) { + throw new MLContextException("Cache exception while converting matrix object to List<String> CSV format", e); + } + } + + /** + * Convert a {@code MatrixObject} to a {@code List<String>} in IJV format. 
+ * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code List<String>} + */ + public static List<String> matrixObjectToListStringIJV(MatrixObject matrixObject) { + try { + MatrixBlock mb = matrixObject.acquireRead(); + + int rows = mb.getNumRows(); + int cols = mb.getNumColumns(); + List<String> list = new ArrayList<String>(); + + if (mb.getNonZeros() > 0) { + if (mb.isInSparseFormat()) { + Iterator<IJV> iter = mb.getSparseBlockIterator(); + StringBuilder sb = null; + while (iter.hasNext()) { + IJV cell = iter.next(); + sb = new StringBuilder(); + sb.append(cell.getI() + 1); + sb.append(" "); + sb.append(cell.getJ() + 1); + sb.append(" "); + sb.append(cell.getV()); + list.add(sb.toString()); + } + } else { + StringBuilder sb = null; + for (int i = 0; i < rows; i++) { + sb = new StringBuilder(); + for (int j = 0; j < cols; j++) { + sb = new StringBuilder(); + sb.append(i + 1); + sb.append(" "); + sb.append(j + 1); + sb.append(" "); + sb.append(mb.getValueDenseUnsafe(i, j)); + list.add(sb.toString()); + } + } + } + } + + matrixObject.release(); + return list; + } catch (CacheException e) { + throw new MLContextException("Cache exception while converting matrix object to List<String> IJV format", e); + } + } + + /** + * Convert a {@code MatrixObject} to a two-dimensional double array. + * + * @param matrixObject + * the {@code MatrixObject} + * @return the {@code MatrixObject} converted to a {@code double[][]} + */ + public static double[][] matrixObjectToDoubleMatrix(MatrixObject matrixObject) { + try { + MatrixBlock mb = matrixObject.acquireRead(); + double[][] matrix = DataConverter.convertToDoubleMatrix(mb); + matrixObject.release(); + return matrix; + } catch (CacheException e) { + throw new MLContextException("Cache exception while converting matrix object to double matrix", e); + } + } + + /** + * Convert a {@code MatrixObject} to a {@code DataFrame}. 
+ * + * @param matrixObject + * the {@code MatrixObject} + * @param sparkExecutionContext + * the Spark execution context + * @return the {@code MatrixObject} converted to a {@code DataFrame} + */ + public static DataFrame matrixObjectToDataFrame(MatrixObject matrixObject, + SparkExecutionContext sparkExecutionContext) { + try { + @SuppressWarnings("unchecked") + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlockMatrix = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext + .getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo); + MatrixCharacteristics matrixCharacteristics = matrixObject.getMatrixCharacteristics(); + + MLContext activeMLContext = (MLContext) MLContextProxy.getActiveMLContext(); + SparkContext sc = activeMLContext.getSparkContext(); + SQLContext sqlContext = new SQLContext(sc); + DataFrame df = RDDConverterUtilsExt.binaryBlockToDataFrame(binaryBlockMatrix, matrixCharacteristics, + sqlContext); + return df; + } catch (DMLRuntimeException e) { + throw new MLContextException("DMLRuntimeException while converting matrix object to DataFrame", e); + } + } + + /** + * Convert a {@code MatrixObject} to a {@code BinaryBlockMatrix}. 
+ * + * @param matrixObject + * the {@code MatrixObject} + * @param sparkExecutionContext + * the Spark execution context + * @return the {@code MatrixObject} converted to a {@code BinaryBlockMatrix} + */ + public static BinaryBlockMatrix matrixObjectToBinaryBlockMatrix(MatrixObject matrixObject, + SparkExecutionContext sparkExecutionContext) { + try { + @SuppressWarnings("unchecked") + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlock = (JavaPairRDD<MatrixIndexes, MatrixBlock>) sparkExecutionContext + .getRDDHandleForMatrixObject(matrixObject, InputInfo.BinaryBlockInputInfo); + MatrixCharacteristics matrixCharacteristics = matrixObject.getMatrixCharacteristics(); + BinaryBlockMatrix binaryBlockMatrix = new BinaryBlockMatrix(binaryBlock, matrixCharacteristics); + return binaryBlockMatrix; + } catch (DMLRuntimeException e) { + throw new MLContextException("DMLRuntimeException while converting matrix object to BinaryBlockMatrix", e); + } + } + +} http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/457bbd3a/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java b/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java new file mode 100644 index 0000000..63e6b64 --- /dev/null +++ b/src/main/java/org/apache/sysml/api/mlcontext/MLContextException.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
/**
 * Unchecked exception representing SystemML exceptions that occur through the
 * MLContext API.
 */
public class MLContextException extends RuntimeException {

	private static final long serialVersionUID = 1L;

	/** Create an exception with no message and no cause. */
	public MLContextException() {
		super();
	}

	/** Create an exception with a message only. */
	public MLContextException(String message) {
		super(message);
	}

	/** Create an exception with a cause only. */
	public MLContextException(Throwable cause) {
		super(cause);
	}

	/** Create an exception with both a message and a cause. */
	public MLContextException(String message, Throwable cause) {
		super(message, cause);
	}

}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.api.mlcontext; + +import java.io.FileNotFoundException; +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Scanner; +import java.util.Set; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.text.WordUtils; +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.rdd.RDD; +import org.apache.spark.sql.DataFrame; +import org.apache.sysml.conf.CompilerConfig; +import org.apache.sysml.conf.CompilerConfig.ConfigType; +import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.conf.DMLConfig; +import org.apache.sysml.parser.ParseException; +import org.apache.sysml.runtime.controlprogram.LocalVariableMap; +import org.apache.sysml.runtime.controlprogram.caching.FrameObject; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.instructions.cp.BooleanObject; +import org.apache.sysml.runtime.instructions.cp.Data; +import org.apache.sysml.runtime.instructions.cp.DoubleObject; +import org.apache.sysml.runtime.instructions.cp.IntObject; +import org.apache.sysml.runtime.instructions.cp.StringObject; +import org.apache.sysml.runtime.matrix.data.MatrixBlock; +import 
org.apache.sysml.runtime.matrix.data.MatrixIndexes; + +/** + * Utility class containing methods for working with the MLContext API. + * + */ +public final class MLContextUtil { + + /** + * Basic data types supported by the MLContext API + */ + @SuppressWarnings("rawtypes") + public static final Class[] BASIC_DATA_TYPES = { Integer.class, Boolean.class, Double.class, String.class }; + + /** + * Complex data types supported by the MLContext API + */ + @SuppressWarnings("rawtypes") + public static final Class[] COMPLEX_DATA_TYPES = { JavaRDD.class, RDD.class, DataFrame.class, + BinaryBlockMatrix.class, Matrix.class, (new double[][] {}).getClass() }; + + /** + * All data types supported by the MLContext API + */ + @SuppressWarnings("rawtypes") + public static final Class[] ALL_SUPPORTED_DATA_TYPES = (Class[]) ArrayUtils.addAll(BASIC_DATA_TYPES, + COMPLEX_DATA_TYPES); + + /** + * Compare two version strings (ie, "1.4.0" and "1.4.1"). + * + * @param versionStr1 + * First version string. + * @param versionStr2 + * Second version string. + * @return If versionStr1 is less than versionStr2, return {@code -1}. If + * versionStr1 equals versionStr2, return {@code 0}. If versionStr1 + * is greater than versionStr2, return {@code 1}. 
+ * @throws MLContextException + * if versionStr1 or versionStr2 is {@code null} + */ + private static int compareVersion(String versionStr1, String versionStr2) { + if (versionStr1 == null) { + throw new MLContextException("First version argument to compareVersion() is null"); + } + if (versionStr2 == null) { + throw new MLContextException("Second version argument to compareVersion() is null"); + } + + Scanner scanner1 = null; + Scanner scanner2 = null; + try { + scanner1 = new Scanner(versionStr1); + scanner2 = new Scanner(versionStr2); + scanner1.useDelimiter("\\."); + scanner2.useDelimiter("\\."); + + while (scanner1.hasNextInt() && scanner2.hasNextInt()) { + int version1 = scanner1.nextInt(); + int version2 = scanner2.nextInt(); + if (version1 < version2) { + return -1; + } else if (version1 > version2) { + return 1; + } + } + + return scanner1.hasNextInt() ? 1 : 0; + } finally { + scanner1.close(); + scanner2.close(); + } + } + + /** + * Determine whether the Spark version is supported. + * + * @param sparkVersion + * Spark version string (ie, "1.5.0"). + * @return {@code true} if Spark version supported; otherwise {@code false}. + */ + public static boolean isSparkVersionSupported(String sparkVersion) { + if (compareVersion(sparkVersion, MLContext.SYSTEMML_MINIMUM_SPARK_VERSION) < 0) { + return false; + } else { + return true; + } + } + + /** + * Check that the Spark version is supported. If it isn't supported, throw + * an MLContextException. + * + * @param sc + * SparkContext + * @throws MLContextException + * thrown if Spark version isn't supported + */ + public static void verifySparkVersionSupported(SparkContext sc) { + if (!MLContextUtil.isSparkVersionSupported(sc.version())) { + throw new MLContextException("SystemML requires Spark " + MLContext.SYSTEMML_MINIMUM_SPARK_VERSION + + " or greater"); + } + } + + /** + * Set default SystemML configuration properties. 
+ */ + public static void setDefaultConfig() { + ConfigurationManager.setGlobalConfig(new DMLConfig()); + } + + /** + * Set SystemML configuration properties based on a configuration file. + * + * @param configFilePath + * Path to configuration file. + * @throws MLContextException + * if configuration file was not found or a parse exception + * occurred + */ + public static void setConfig(String configFilePath) { + try { + DMLConfig config = new DMLConfig(configFilePath); + ConfigurationManager.setGlobalConfig(config); + } catch (ParseException e) { + throw new MLContextException("Parse Exception when setting config", e); + } catch (FileNotFoundException e) { + throw new MLContextException("File not found (" + configFilePath + ") when setting config", e); + } + } + + /** + * Set SystemML compiler configuration properties for MLContext + */ + public static void setCompilerConfig() { + CompilerConfig compilerConfig = new CompilerConfig(); + compilerConfig.set(ConfigType.IGNORE_UNSPECIFIED_ARGS, true); + compilerConfig.set(ConfigType.REJECT_READ_WRITE_UNKNOWNS, false); + compilerConfig.set(ConfigType.ALLOW_CSE_PERSISTENT_READS, false); + ConfigurationManager.setGlobalConfig(compilerConfig); + } + + /** + * Convenience method to generate a {@code Map<String, Object>} of key/value + * pairs. + * <p> + * Example:<br> + * {@code Map<String, Object> inputMap = MLContextUtil.generateInputMap("A", 1, "B", "two", "C", 3);} + * <br> + * <br> + * This is equivalent to:<br> + * <code>Map<String, Object> inputMap = new LinkedHashMap<String, Object>(){{ + * <br>put("A", 1); + * <br>put("B", "two"); + * <br>put("C", 3); + * <br>}};</code> + * + * @param objs + * List of String/Object pairs + * @return Map of String/Object pairs + * @throws MLContextException + * if the number of arguments is not an even number + */ + public static Map<String, Object> generateInputMap(Object... 
objs) { + int len = objs.length; + if ((len & 1) == 1) { + throw new MLContextException("The number of arguments needs to be an even number"); + } + Map<String, Object> map = new LinkedHashMap<String, Object>(); + int i = 0; + while (i < len) { + map.put((String) objs[i++], objs[i++]); + } + return map; + } + + /** + * Verify that the types of input values are supported. + * + * @param inputs + * Map of String/Object pairs + * @throws MLContextException + * if an input value type is not supported + */ + public static void checkInputValueTypes(Map<String, Object> inputs) { + for (Entry<String, Object> entry : inputs.entrySet()) { + checkInputValueType(entry.getKey(), entry.getValue()); + } + } + + /** + * Verify that the type of input value is supported. + * + * @param name + * The name of the input + * @param value + * The value of the input + * @throws MLContextException + * if the input value type is not supported + */ + public static void checkInputValueType(String name, Object value) { + + if (name == null) { + throw new MLContextException("No input name supplied"); + } else if (value == null) { + throw new MLContextException("No input value supplied"); + } + + Object o = value; + boolean supported = false; + for (Class<?> clazz : ALL_SUPPORTED_DATA_TYPES) { + if (o.getClass().equals(clazz)) { + supported = true; + break; + } else if (clazz.isAssignableFrom(o.getClass())) { + supported = true; + break; + } + } + if (!supported) { + throw new MLContextException("Input name (\"" + value + "\") value type not supported: " + o.getClass()); + } + } + + /** + * Verify that the type of input parameter value is supported. 
+ * + * @param parameterName + * The name of the input parameter + * @param parameterValue + * The value of the input parameter + * @throws MLContextException + * if the input parameter value type is not supported + */ + public static void checkInputParameterType(String parameterName, Object parameterValue) { + + if (parameterName == null) { + throw new MLContextException("No parameter name supplied"); + } else if (parameterValue == null) { + throw new MLContextException("No parameter value supplied"); + } else if (!parameterName.startsWith("$")) { + throw new MLContextException("Input parameter name must start with a $"); + } + + Object o = parameterValue; + boolean supported = false; + for (Class<?> clazz : BASIC_DATA_TYPES) { + if (o.getClass().equals(clazz)) { + supported = true; + break; + } else if (clazz.isAssignableFrom(o.getClass())) { + supported = true; + break; + } + } + if (!supported) { + throw new MLContextException("Input parameter (\"" + parameterName + "\") value type not supported: " + + o.getClass()); + } + } + + /** + * Is the object one of the supported basic data types? (Integer, Boolean, + * Double, String) + * + * @param object + * the object type to be examined + * @return {@code true} if type is a basic data type; otherwise + * {@code false}. + */ + public static boolean isBasicType(Object object) { + for (Class<?> clazz : BASIC_DATA_TYPES) { + if (object.getClass().equals(clazz)) { + return true; + } else if (clazz.isAssignableFrom(object.getClass())) { + return true; + } + } + return false; + } + + /** + * Is the object one of the supported complex data types? (JavaRDD, RDD, + * DataFrame, BinaryBlockMatrix, Matrix, double[][]) + * + * @param object + * the object type to be examined + * @return {@code true} if type is a complexe data type; otherwise + * {@code false}. 
+ */ + public static boolean isComplexType(Object object) { + for (Class<?> clazz : COMPLEX_DATA_TYPES) { + if (object.getClass().equals(clazz)) { + return true; + } else if (clazz.isAssignableFrom(object.getClass())) { + return true; + } + } + return false; + } + + /** + * Converts non-string basic input parameter values to strings to pass to + * the parser. + * + * @param basicInputParameterMap + * map of input parameters + * @param scriptType + * {@code ScriptType.DML} or {@code ScriptType.PYDML} + * @return map of String/String name/value pairs + */ + public static Map<String, String> convertInputParametersForParser(Map<String, Object> basicInputParameterMap, + ScriptType scriptType) { + if (basicInputParameterMap == null) { + return null; + } + if (scriptType == null) { + throw new MLContextException("ScriptType needs to be specified"); + } + Map<String, String> convertedMap = new HashMap<String, String>(); + for (Entry<String, Object> entry : basicInputParameterMap.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + if (value == null) { + throw new MLContextException("Input parameter value is null for: " + entry.getKey()); + } else if (value instanceof Integer) { + convertedMap.put(key, Integer.toString((Integer) value)); + } else if (value instanceof Boolean) { + if (scriptType == ScriptType.DML) { + convertedMap.put(key, String.valueOf((Boolean) value).toUpperCase()); + } else { + convertedMap.put(key, WordUtils.capitalize(String.valueOf((Boolean) value))); + } + } else if (value instanceof Double) { + convertedMap.put(key, Double.toString((Double) value)); + } else if (value instanceof String) { + convertedMap.put(key, (String) value); + } + } + return convertedMap; + } + + /** + * Convert input types to internal SystemML representations + * + * @param parameterName + * The name of the input parameter + * @param parameterValue + * The value of the input parameter + * @return input in SystemML data representation + */ + public 
static Data convertInputType(String parameterName, Object parameterValue) { + return convertInputType(parameterName, parameterValue, null); + } + + /** + * Convert input types to internal SystemML representations + * + * @param parameterName + * The name of the input parameter + * @param parameterValue + * The value of the input parameter + * @param matrixMetadata + * matrix metadata + * @return input in SystemML data representation + */ + public static Data convertInputType(String parameterName, Object parameterValue, MatrixMetadata matrixMetadata) { + String name = parameterName; + Object value = parameterValue; + if (name == null) { + throw new MLContextException("Input parameter name is null"); + } else if (value == null) { + throw new MLContextException("Input parameter value is null for: " + parameterName); + } else if (value instanceof JavaRDD<?>) { + @SuppressWarnings("unchecked") + JavaRDD<String> javaRDD = (JavaRDD<String>) value; + MatrixObject matrixObject; + if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) { + matrixObject = MLContextConversionUtil.javaRDDStringIJVToMatrixObject(name, javaRDD, matrixMetadata); + } else { + matrixObject = MLContextConversionUtil.javaRDDStringCSVToMatrixObject(name, javaRDD, matrixMetadata); + } + return matrixObject; + } else if (value instanceof RDD<?>) { + @SuppressWarnings("unchecked") + RDD<String> rdd = (RDD<String>) value; + MatrixObject matrixObject; + if ((matrixMetadata != null) && (matrixMetadata.getMatrixFormat() == MatrixFormat.IJV)) { + matrixObject = MLContextConversionUtil.rddStringIJVToMatrixObject(name, rdd, matrixMetadata); + } else { + matrixObject = MLContextConversionUtil.rddStringCSVToMatrixObject(name, rdd, matrixMetadata); + } + + return matrixObject; + } else if (value instanceof DataFrame) { + DataFrame dataFrame = (DataFrame) value; + MatrixObject matrixObject = MLContextConversionUtil + .dataFrameToMatrixObject(name, dataFrame, matrixMetadata); + return 
matrixObject; + } else if (value instanceof BinaryBlockMatrix) { + BinaryBlockMatrix binaryBlockMatrix = (BinaryBlockMatrix) value; + if (matrixMetadata == null) { + matrixMetadata = binaryBlockMatrix.getMatrixMetadata(); + } + JavaPairRDD<MatrixIndexes, MatrixBlock> binaryBlocks = binaryBlockMatrix.getBinaryBlocks(); + MatrixObject matrixObject = MLContextConversionUtil.binaryBlocksToMatrixObject(name, binaryBlocks, + matrixMetadata); + return matrixObject; + } else if (value instanceof Matrix) { + Matrix matrix = (Matrix) value; + MatrixObject matrixObject = matrix.asMatrixObject(); + return matrixObject; + } else if (value instanceof double[][]) { + double[][] doubleMatrix = (double[][]) value; + MatrixObject matrixObject = MLContextConversionUtil.doubleMatrixToMatrixObject(name, doubleMatrix, + matrixMetadata); + return matrixObject; + } else if (value instanceof Integer) { + Integer i = (Integer) value; + IntObject iObj = new IntObject(i); + return iObj; + } else if (value instanceof Double) { + Double d = (Double) value; + DoubleObject dObj = new DoubleObject(d); + return dObj; + } else if (value instanceof String) { + String s = (String) value; + StringObject sObj = new StringObject(s); + return sObj; + } else if (value instanceof Boolean) { + Boolean b = (Boolean) value; + BooleanObject bObj = new BooleanObject(b); + return bObj; + } + return null; + } + + /** + * Return the default matrix block size. + * + * @return the default matrix block size + */ + public static int defaultBlockSize() { + DMLConfig conf = ConfigurationManager.getDMLConfig(); + int blockSize = conf.getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE); + return blockSize; + } + + /** + * Return the location of the scratch space directory. 
+ * + * @return the lcoation of the scratch space directory + */ + public static String scratchSpace() { + DMLConfig conf = ConfigurationManager.getDMLConfig(); + String scratchSpace = conf.getTextValue(DMLConfig.SCRATCH_SPACE); + return scratchSpace; + } + + /** + * Return a double-quoted string with inner single and double quotes + * escaped. + * + * @param str + * the original string + * @return double-quoted string with inner single and double quotes escaped + */ + public static String quotedString(String str) { + if (str == null) { + return null; + } + + StringBuilder sb = new StringBuilder(); + sb.append("\""); + for (int i = 0; i < str.length(); i++) { + char ch = str.charAt(i); + if ((ch == '\'') || (ch == '"')) { + if ((i > 0) && (str.charAt(i - 1) != '\\')) { + sb.append('\\'); + } else if (i == 0) { + sb.append('\\'); + } + } + sb.append(ch); + } + sb.append("\""); + + return sb.toString(); + } + + /** + * Display the keys and values in a Map + * + * @param mapName + * the name of the map + * @param map + * Map of String keys and Object values + * @return the keys and values in the Map as a String + */ + public static String displayMap(String mapName, Map<String, Object> map) { + StringBuilder sb = new StringBuilder(); + sb.append(mapName); + sb.append(":\n"); + Set<String> keys = map.keySet(); + if (keys.isEmpty()) { + sb.append("None\n"); + } else { + int count = 0; + for (String key : keys) { + sb.append(" ["); + sb.append(++count); + sb.append("] "); + sb.append(key); + sb.append(": "); + sb.append(map.get(key)); + sb.append("\n"); + } + } + return sb.toString(); + } + + /** + * Display the values in a Set + * + * @param setName + * the name of the Set + * @param set + * Set of String values + * @return the values in the Set as a String + */ + public static String displaySet(String setName, Set<String> set) { + StringBuilder sb = new StringBuilder(); + sb.append(setName); + sb.append(":\n"); + if (set.isEmpty()) { + sb.append("None\n"); + } else { + 
int count = 0; + for (String value : set) { + sb.append(" ["); + sb.append(++count); + sb.append("] "); + sb.append(value); + sb.append("\n"); + } + } + return sb.toString(); + } + + /** + * Display the keys and values in the symbol table + * + * @param name + * the name of the symbol table + * @param symbolTable + * the LocalVariableMap + * @return the keys and values in the symbol table as a String + */ + public static String displaySymbolTable(String name, LocalVariableMap symbolTable) { + StringBuilder sb = new StringBuilder(); + sb.append(name); + sb.append(":\n"); + sb.append(displaySymbolTable(symbolTable)); + return sb.toString(); + } + + /** + * Display the keys and values in the symbol table + * + * @param symbolTable + * the LocalVariableMap + * @return the keys and values in the symbol table as a String + */ + public static String displaySymbolTable(LocalVariableMap symbolTable) { + StringBuilder sb = new StringBuilder(); + Set<String> keys = symbolTable.keySet(); + if (keys.isEmpty()) { + sb.append("None\n"); + } else { + int count = 0; + for (String key : keys) { + sb.append(" ["); + sb.append(++count); + sb.append("]"); + + sb.append(" ("); + sb.append(determineOutputTypeAsString(symbolTable, key)); + sb.append(") "); + + sb.append(key); + + sb.append(": "); + sb.append(symbolTable.get(key)); + sb.append("\n"); + } + } + return sb.toString(); + } + + /** + * Obtain a symbol table output type as a String + * + * @param symbolTable + * the symbol table + * @param outputName + * the name of the output variable + * @return the symbol table output type for a variable as a String + */ + public static String determineOutputTypeAsString(LocalVariableMap symbolTable, String outputName) { + Data data = symbolTable.get(outputName); + if (data instanceof BooleanObject) { + return "Boolean"; + } else if (data instanceof DoubleObject) { + return "Double"; + } else if (data instanceof IntObject) { + return "Long"; + } else if (data instanceof StringObject) { + 
return "String"; + } else if (data instanceof MatrixObject) { + return "Matrix"; + } else if (data instanceof FrameObject) { + return "Frame"; + } + return "Unknown"; + } + + /** + * Obtain a display of script inputs. + * + * @param name + * the title to display for the inputs + * @param map + * the map of inputs + * @return the script inputs represented as a String + */ + public static String displayInputs(String name, Map<String, Object> map) { + StringBuilder sb = new StringBuilder(); + sb.append(name); + sb.append(":\n"); + Set<String> keys = map.keySet(); + if (keys.isEmpty()) { + sb.append("None\n"); + } else { + int count = 0; + for (String key : keys) { + Object object = map.get(key); + @SuppressWarnings("rawtypes") + Class clazz = object.getClass(); + String type = clazz.getSimpleName(); + if (object instanceof JavaRDD<?>) { + type = "JavaRDD"; + } else if (object instanceof RDD<?>) { + type = "RDD"; + } + + sb.append(" ["); + sb.append(++count); + sb.append("]"); + + sb.append(" ("); + sb.append(type); + sb.append(") "); + + sb.append(key); + sb.append(": "); + String str = object.toString(); + str = StringUtils.abbreviate(str, 100); + sb.append(str); + sb.append("\n"); + } + } + return sb.toString(); + } + + /** + * Obtain a display of the script outputs. + * + * @param name + * the title to display for the outputs + * @param outputNames + * the names of the output variables + * @param symbolTable + * the symbol table + * @return the script outputs represented as a String + * + */ + public static String displayOutputs(String name, Set<String> outputNames, LocalVariableMap symbolTable) { + StringBuilder sb = new StringBuilder(); + sb.append(name); + sb.append(":\n"); + sb.append(displayOutputs(outputNames, symbolTable)); + return sb.toString(); + } + + /** + * Obtain a display of the script outputs. 
+ * + * @param outputNames + * the names of the output variables + * @param symbolTable + * the symbol table + * @return the script outputs represented as a String + * + */ + public static String displayOutputs(Set<String> outputNames, LocalVariableMap symbolTable) { + StringBuilder sb = new StringBuilder(); + if (outputNames.isEmpty()) { + sb.append("None\n"); + } else { + int count = 0; + for (String outputName : outputNames) { + sb.append(" ["); + sb.append(++count); + sb.append("] "); + + if (symbolTable.get(outputName) != null) { + sb.append("("); + sb.append(determineOutputTypeAsString(symbolTable, outputName)); + sb.append(") "); + } + + sb.append(outputName); + + if (symbolTable.get(outputName) != null) { + sb.append(": "); + sb.append(symbolTable.get(outputName)); + } + + sb.append("\n"); + } + } + return sb.toString(); + } + + /** + * The SystemML welcome message + * + * @return the SystemML welcome message + */ + public static String welcomeMessage() { + StringBuilder sb = new StringBuilder(); + sb.append("\nWelcome to Apache SystemML!\n"); + return sb.toString(); + } + + /** + * Generate a String history entry for a script. + * + * @param script + * the script + * @param when + * when the script was executed + * @return a script history entry as a String + */ + public static String createHistoryForScript(Script script, long when) { + DateFormat dateFormat = new SimpleDateFormat("MM/dd/yyyy HH:mm:ss.SSS"); + StringBuilder sb = new StringBuilder(); + sb.append("Script Name: " + script.getName() + "\n"); + sb.append("When: " + dateFormat.format(new Date(when)) + "\n"); + sb.append(script.displayInputs()); + sb.append(script.displayOutputs()); + sb.append(script.displaySymbolTable()); + return sb.toString(); + } + + /** + * Generate a String listing of the script execution history. 
+ * + * @param scriptHistory + * the list of script history entries + * @return the listing of the script execution history as a String + */ + public static String displayScriptHistory(List<String> scriptHistory) { + StringBuilder sb = new StringBuilder(); + sb.append("MLContext Script History:\n"); + if (scriptHistory.isEmpty()) { + sb.append("None"); + } + int i = 1; + for (String history : scriptHistory) { + sb.append("--------------------------------------------\n"); + sb.append("#" + (i++) + ":\n"); + sb.append(history); + } + return sb.toString(); + } + +}
